Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0cceabfc
Unverified
Commit
0cceabfc
authored
Aug 03, 2020
by
Yiming Shi
Committed by
GitHub
Aug 03, 2020
Browse files
Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor
parents
17821c0d
39ee0ac9
Changes
339
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
990 additions
and
474 deletions
+990
-474
research/attention_ocr/python/datasets/fsns_test.py
research/attention_ocr/python/datasets/fsns_test.py
+1
-1
research/attention_ocr/python/datasets/testdata/fsns/download_data.py
...ention_ocr/python/datasets/testdata/fsns/download_data.py
+3
-2
research/attention_ocr/python/demo_inference.py
research/attention_ocr/python/demo_inference.py
+12
-11
research/attention_ocr/python/demo_inference_test.py
research/attention_ocr/python/demo_inference_test.py
+40
-39
research/attention_ocr/python/eval.py
research/attention_ocr/python/eval.py
+3
-3
research/attention_ocr/python/inception_preprocessing.py
research/attention_ocr/python/inception_preprocessing.py
+16
-16
research/attention_ocr/python/metrics.py
research/attention_ocr/python/metrics.py
+20
-18
research/attention_ocr/python/metrics_test.py
research/attention_ocr/python/metrics_test.py
+7
-7
research/attention_ocr/python/model.py
research/attention_ocr/python/model.py
+303
-129
research/attention_ocr/python/model_export.py
research/attention_ocr/python/model_export.py
+198
-0
research/attention_ocr/python/model_export_lib.py
research/attention_ocr/python/model_export_lib.py
+108
-0
research/attention_ocr/python/model_export_test.py
research/attention_ocr/python/model_export_test.py
+160
-0
research/attention_ocr/python/model_test.py
research/attention_ocr/python/model_test.py
+71
-49
research/attention_ocr/python/sequence_layers.py
research/attention_ocr/python/sequence_layers.py
+12
-12
research/attention_ocr/python/sequence_layers_test.py
research/attention_ocr/python/sequence_layers_test.py
+2
-2
research/attention_ocr/python/train.py
research/attention_ocr/python/train.py
+11
-11
research/attention_ocr/python/utils.py
research/attention_ocr/python/utils.py
+23
-6
research/autoencoder/AdditiveGaussianNoiseAutoencoderRunner.py
...rch/autoencoder/AdditiveGaussianNoiseAutoencoderRunner.py
+0
-58
research/autoencoder/AutoencoderRunner.py
research/autoencoder/AutoencoderRunner.py
+0
-55
research/autoencoder/MaskingNoiseAutoencoderRunner.py
research/autoencoder/MaskingNoiseAutoencoderRunner.py
+0
-55
No files found.
Too many changes to show.
To preserve performance only
339 of 339+
files are displayed.
Plain diff
Email patch
research/attention_ocr/python/datasets/fsns_test.py
View file @
0cceabfc
...
@@ -91,7 +91,7 @@ class FsnsTest(tf.test.TestCase):
...
@@ -91,7 +91,7 @@ class FsnsTest(tf.test.TestCase):
image_tf
,
label_tf
=
provider
.
get
([
'image'
,
'label'
])
image_tf
,
label_tf
=
provider
.
get
([
'image'
,
'label'
])
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
with
slim
.
queues
.
QueueRunners
(
sess
):
with
slim
.
queues
.
QueueRunners
(
sess
):
image_np
,
label_np
=
sess
.
run
([
image_tf
,
label_tf
])
image_np
,
label_np
=
sess
.
run
([
image_tf
,
label_tf
])
...
...
research/attention_ocr/python/datasets/testdata/fsns/download_data.py
View file @
0cceabfc
...
@@ -10,7 +10,8 @@ KEEP_NUM_RECORDS = 5
...
@@ -10,7 +10,8 @@ KEEP_NUM_RECORDS = 5
print
(
'Downloading %s ...'
%
URL
)
print
(
'Downloading %s ...'
%
URL
)
urllib
.
request
.
urlretrieve
(
URL
,
DST_ORIG
)
urllib
.
request
.
urlretrieve
(
URL
,
DST_ORIG
)
print
(
'Writing %d records from %s to %s ...'
%
(
KEEP_NUM_RECORDS
,
DST_ORIG
,
DST
))
print
(
'Writing %d records from %s to %s ...'
%
(
KEEP_NUM_RECORDS
,
DST_ORIG
,
DST
))
with
tf
.
io
.
TFRecordWriter
(
DST
)
as
writer
:
with
tf
.
io
.
TFRecordWriter
(
DST
)
as
writer
:
for
raw_record
in
itertools
.
islice
(
tf
.
python_io
.
tf_record_iterator
(
DST_ORIG
),
KEEP_NUM_RECORDS
):
for
raw_record
in
itertools
.
islice
(
tf
.
compat
.
v1
.
python_io
.
tf_record_iterator
(
DST_ORIG
),
KEEP_NUM_RECORDS
):
writer
.
write
(
raw_record
)
writer
.
write
(
raw_record
)
research/attention_ocr/python/demo_inference.py
View file @
0cceabfc
...
@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
...
@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
path
=
file_pattern
%
i
path
=
file_pattern
%
i
print
(
"Reading %s"
%
path
)
print
(
"Reading %s"
%
path
)
pil_image
=
PIL
.
Image
.
open
(
tf
.
gfile
.
GFile
(
path
,
'rb'
))
pil_image
=
PIL
.
Image
.
open
(
tf
.
io
.
gfile
.
GFile
(
path
,
'rb'
))
images_actual_data
[
i
,
...]
=
np
.
asarray
(
pil_image
)
images_actual_data
[
i
,
...]
=
np
.
asarray
(
pil_image
)
return
images_actual_data
return
images_actual_data
...
@@ -58,12 +58,13 @@ def create_model(batch_size, dataset_name):
...
@@ -58,12 +58,13 @@ def create_model(batch_size, dataset_name):
width
,
height
=
get_dataset_image_size
(
dataset_name
)
width
,
height
=
get_dataset_image_size
(
dataset_name
)
dataset
=
common_flags
.
create_dataset
(
split_name
=
FLAGS
.
split_name
)
dataset
=
common_flags
.
create_dataset
(
split_name
=
FLAGS
.
split_name
)
model
=
common_flags
.
create_model
(
model
=
common_flags
.
create_model
(
num_char_classes
=
dataset
.
num_char_classes
,
num_char_classes
=
dataset
.
num_char_classes
,
seq_length
=
dataset
.
max_sequence_length
,
seq_length
=
dataset
.
max_sequence_length
,
num_views
=
dataset
.
num_of_views
,
num_views
=
dataset
.
num_of_views
,
null_code
=
dataset
.
null_code
,
null_code
=
dataset
.
null_code
,
charset
=
dataset
.
charset
)
charset
=
dataset
.
charset
)
raw_images
=
tf
.
placeholder
(
tf
.
uint8
,
shape
=
[
batch_size
,
height
,
width
,
3
])
raw_images
=
tf
.
compat
.
v1
.
placeholder
(
tf
.
uint8
,
shape
=
[
batch_size
,
height
,
width
,
3
])
images
=
tf
.
map_fn
(
data_provider
.
preprocess_image
,
raw_images
,
images
=
tf
.
map_fn
(
data_provider
.
preprocess_image
,
raw_images
,
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
endpoints
=
model
.
create_base
(
images
,
labels_one_hot
=
None
)
endpoints
=
model
.
create_base
(
images
,
labels_one_hot
=
None
)
...
@@ -76,9 +77,9 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
...
@@ -76,9 +77,9 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
images_data
=
load_images
(
image_path_pattern
,
batch_size
,
images_data
=
load_images
(
image_path_pattern
,
batch_size
,
dataset_name
)
dataset_name
)
session_creator
=
monitored_session
.
ChiefSessionCreator
(
session_creator
=
monitored_session
.
ChiefSessionCreator
(
checkpoint_filename_with_path
=
checkpoint
)
checkpoint_filename_with_path
=
checkpoint
)
with
monitored_session
.
MonitoredSession
(
with
monitored_session
.
MonitoredSession
(
session_creator
=
session_creator
)
as
sess
:
session_creator
=
session_creator
)
as
sess
:
predictions
=
sess
.
run
(
endpoints
.
predicted_text
,
predictions
=
sess
.
run
(
endpoints
.
predicted_text
,
feed_dict
=
{
images_placeholder
:
images_data
})
feed_dict
=
{
images_placeholder
:
images_data
})
return
[
pr_bytes
.
decode
(
'utf-8'
)
for
pr_bytes
in
predictions
.
tolist
()]
return
[
pr_bytes
.
decode
(
'utf-8'
)
for
pr_bytes
in
predictions
.
tolist
()]
...
@@ -87,10 +88,10 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
...
@@ -87,10 +88,10 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
def
main
(
_
):
def
main
(
_
):
print
(
"Predicted strings:"
)
print
(
"Predicted strings:"
)
predictions
=
run
(
FLAGS
.
checkpoint
,
FLAGS
.
batch_size
,
FLAGS
.
dataset_name
,
predictions
=
run
(
FLAGS
.
checkpoint
,
FLAGS
.
batch_size
,
FLAGS
.
dataset_name
,
FLAGS
.
image_path_pattern
)
FLAGS
.
image_path_pattern
)
for
line
in
predictions
:
for
line
in
predictions
:
print
(
line
)
print
(
line
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
tf
.
compat
.
v1
.
app
.
run
()
research/attention_ocr/python/demo_inference_test.py
View file @
0cceabfc
...
@@ -14,12 +14,13 @@ class DemoInferenceTest(tf.test.TestCase):
...
@@ -14,12 +14,13 @@ class DemoInferenceTest(tf.test.TestCase):
super
(
DemoInferenceTest
,
self
).
setUp
()
super
(
DemoInferenceTest
,
self
).
setUp
()
for
suffix
in
[
'.meta'
,
'.index'
,
'.data-00000-of-00001'
]:
for
suffix
in
[
'.meta'
,
'.index'
,
'.data-00000-of-00001'
]:
filename
=
_CHECKPOINT
+
suffix
filename
=
_CHECKPOINT
+
suffix
self
.
assertTrue
(
tf
.
gfile
.
E
xists
(
filename
),
self
.
assertTrue
(
tf
.
io
.
gfile
.
e
xists
(
filename
),
msg
=
'Missing checkpoint file %s. '
msg
=
'Missing checkpoint file %s. '
'Please download and extract it from %s'
%
'Please download and extract it from %s'
%
(
filename
,
_CHECKPOINT_URL
))
(
filename
,
_CHECKPOINT_URL
))
self
.
_batch_size
=
32
self
.
_batch_size
=
32
tf
.
flags
.
FLAGS
.
dataset_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'datasets/testdata/fsns'
)
tf
.
flags
.
FLAGS
.
dataset_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'datasets/testdata/fsns'
)
def
test_moving_variables_properly_loaded_from_a_checkpoint
(
self
):
def
test_moving_variables_properly_loaded_from_a_checkpoint
(
self
):
batch_size
=
32
batch_size
=
32
...
@@ -30,15 +31,15 @@ class DemoInferenceTest(tf.test.TestCase):
...
@@ -30,15 +31,15 @@ class DemoInferenceTest(tf.test.TestCase):
images_data
=
demo_inference
.
load_images
(
image_path_pattern
,
batch_size
,
images_data
=
demo_inference
.
load_images
(
image_path_pattern
,
batch_size
,
dataset_name
)
dataset_name
)
tensor_name
=
'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
tensor_name
=
'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
moving_mean_tf
=
tf
.
get_default_graph
().
get_tensor_by_name
(
moving_mean_tf
=
tf
.
compat
.
v1
.
get_default_graph
().
get_tensor_by_name
(
tensor_name
+
':0'
)
tensor_name
+
':0'
)
reader
=
tf
.
train
.
NewCheckpointReader
(
_CHECKPOINT
)
reader
=
tf
.
compat
.
v1
.
train
.
NewCheckpointReader
(
_CHECKPOINT
)
moving_mean_expected
=
reader
.
get_tensor
(
tensor_name
)
moving_mean_expected
=
reader
.
get_tensor
(
tensor_name
)
session_creator
=
monitored_session
.
ChiefSessionCreator
(
session_creator
=
monitored_session
.
ChiefSessionCreator
(
checkpoint_filename_with_path
=
_CHECKPOINT
)
checkpoint_filename_with_path
=
_CHECKPOINT
)
with
monitored_session
.
MonitoredSession
(
with
monitored_session
.
MonitoredSession
(
session_creator
=
session_creator
)
as
sess
:
session_creator
=
session_creator
)
as
sess
:
moving_mean_np
=
sess
.
run
(
moving_mean_tf
,
moving_mean_np
=
sess
.
run
(
moving_mean_tf
,
feed_dict
=
{
images_placeholder
:
images_data
})
feed_dict
=
{
images_placeholder
:
images_data
})
...
@@ -50,38 +51,38 @@ class DemoInferenceTest(tf.test.TestCase):
...
@@ -50,38 +51,38 @@ class DemoInferenceTest(tf.test.TestCase):
'fsns'
,
'fsns'
,
image_path_pattern
)
image_path_pattern
)
self
.
assertEqual
([
self
.
assertEqual
([
u
'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░'
,
u
'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░'
,
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░'
,
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░'
,
'Avenue Charles Gounod░░░░░░░░░░░░░░░░'
,
'Avenue Charles Gounod░░░░░░░░░░░░░░░░'
,
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░'
,
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░'
,
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░'
,
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░'
,
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░'
,
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░'
,
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░'
,
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░'
,
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░'
,
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░'
,
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░'
,
# GT='Rue Thérésa'
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░'
,
# GT='Rue Thérésa'
'Route de la Balme░░░░░░░░░░░░░░░░░░░░'
,
'Route de la Balme░░░░░░░░░░░░░░░░░░░░'
,
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░'
,
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░'
,
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░'
,
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░'
,
'Place de la Mairie░░░░░░░░░░░░░░░░░░░'
,
'Place de la Mairie░░░░░░░░░░░░░░░░░░░'
,
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue de la Libération░░░░░░░░░░░░░░░░░'
,
'Rue de la Libération░░░░░░░░░░░░░░░░░'
,
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░'
,
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░'
,
'Avenue de la Grand Mare░░░░░░░░░░░░░░'
,
'Avenue de la Grand Mare░░░░░░░░░░░░░░'
,
'Rue Pierre Brossolette░░░░░░░░░░░░░░░'
,
'Rue Pierre Brossolette░░░░░░░░░░░░░░░'
,
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░'
,
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░'
,
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░'
,
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░'
,
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░'
,
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░'
,
'Impasse Pierre Mourgues░░░░░░░░░░░░░░'
,
'Impasse Pierre Mourgues░░░░░░░░░░░░░░'
,
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
],
predictions
)
],
predictions
)
...
...
research/attention_ocr/python/eval.py
View file @
0cceabfc
...
@@ -45,8 +45,8 @@ flags.DEFINE_integer('number_of_steps', None,
...
@@ -45,8 +45,8 @@ flags.DEFINE_integer('number_of_steps', None,
def
main
(
_
):
def
main
(
_
):
if
not
tf
.
gfile
.
E
xists
(
FLAGS
.
eval_log_dir
):
if
not
tf
.
io
.
gfile
.
e
xists
(
FLAGS
.
eval_log_dir
):
tf
.
gfile
.
M
ake
D
irs
(
FLAGS
.
eval_log_dir
)
tf
.
io
.
gfile
.
m
ake
d
irs
(
FLAGS
.
eval_log_dir
)
dataset
=
common_flags
.
create_dataset
(
split_name
=
FLAGS
.
split_name
)
dataset
=
common_flags
.
create_dataset
(
split_name
=
FLAGS
.
split_name
)
model
=
common_flags
.
create_model
(
dataset
.
num_char_classes
,
model
=
common_flags
.
create_model
(
dataset
.
num_char_classes
,
...
@@ -62,7 +62,7 @@ def main(_):
...
@@ -62,7 +62,7 @@ def main(_):
eval_ops
=
model
.
create_summaries
(
eval_ops
=
model
.
create_summaries
(
data
,
endpoints
,
dataset
.
charset
,
is_training
=
False
)
data
,
endpoints
,
dataset
.
charset
,
is_training
=
False
)
slim
.
get_or_create_global_step
()
slim
.
get_or_create_global_step
()
session_config
=
tf
.
ConfigProto
(
device_count
=
{
"GPU"
:
0
})
session_config
=
tf
.
compat
.
v1
.
ConfigProto
(
device_count
=
{
"GPU"
:
0
})
slim
.
evaluation
.
evaluation_loop
(
slim
.
evaluation
.
evaluation_loop
(
master
=
FLAGS
.
master
,
master
=
FLAGS
.
master
,
checkpoint_dir
=
FLAGS
.
train_log_dir
,
checkpoint_dir
=
FLAGS
.
train_log_dir
,
...
...
research/attention_ocr/python/inception_preprocessing.py
View file @
0cceabfc
...
@@ -38,7 +38,7 @@ def apply_with_random_selector(x, func, num_cases):
...
@@ -38,7 +38,7 @@ def apply_with_random_selector(x, func, num_cases):
The result of func(x, sel), where func receives the value of the
The result of func(x, sel), where func receives the value of the
selector as a python integer, but sel is sampled dynamically.
selector as a python integer, but sel is sampled dynamically.
"""
"""
sel
=
tf
.
random
_
uniform
([],
maxval
=
num_cases
,
dtype
=
tf
.
int32
)
sel
=
tf
.
random
.
uniform
([],
maxval
=
num_cases
,
dtype
=
tf
.
int32
)
# Pass the real x only to one of the func calls.
# Pass the real x only to one of the func calls.
return
control_flow_ops
.
merge
([
return
control_flow_ops
.
merge
([
func
(
control_flow_ops
.
switch
(
x
,
tf
.
equal
(
sel
,
case
))[
1
],
case
)
func
(
control_flow_ops
.
switch
(
x
,
tf
.
equal
(
sel
,
case
))[
1
],
case
)
...
@@ -64,7 +64,7 @@ def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
...
@@ -64,7 +64,7 @@ def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
Raises:
Raises:
ValueError: if color_ordering not in [0, 3]
ValueError: if color_ordering not in [0, 3]
"""
"""
with
tf
.
name_scope
(
scope
,
'distort_color'
,
[
image
]):
with
tf
.
compat
.
v1
.
name_scope
(
scope
,
'distort_color'
,
[
image
]):
if
fast_mode
:
if
fast_mode
:
if
color_ordering
==
0
:
if
color_ordering
==
0
:
image
=
tf
.
image
.
random_brightness
(
image
,
max_delta
=
32.
/
255.
)
image
=
tf
.
image
.
random_brightness
(
image
,
max_delta
=
32.
/
255.
)
...
@@ -131,7 +131,7 @@ def distorted_bounding_box_crop(image,
...
@@ -131,7 +131,7 @@ def distorted_bounding_box_crop(image,
Returns:
Returns:
A tuple, a 3-D Tensor cropped_image and the distorted bbox
A tuple, a 3-D Tensor cropped_image and the distorted bbox
"""
"""
with
tf
.
name_scope
(
scope
,
'distorted_bounding_box_crop'
,
[
image
,
bbox
]):
with
tf
.
compat
.
v1
.
name_scope
(
scope
,
'distorted_bounding_box_crop'
,
[
image
,
bbox
]):
# Each bounding box has shape [1, num_boxes, box coords] and
# Each bounding box has shape [1, num_boxes, box coords] and
# the coordinates are ordered [ymin, xmin, ymax, xmax].
# the coordinates are ordered [ymin, xmin, ymax, xmax].
...
@@ -143,7 +143,7 @@ def distorted_bounding_box_crop(image,
...
@@ -143,7 +143,7 @@ def distorted_bounding_box_crop(image,
# bounding box. If no box is supplied, then we assume the bounding box is
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
# the entire image.
sample_distorted_bounding_box
=
tf
.
image
.
sample_distorted_bounding_box
(
sample_distorted_bounding_box
=
tf
.
image
.
sample_distorted_bounding_box
(
tf
.
shape
(
image
),
image_size
=
tf
.
shape
(
input
=
image
),
bounding_boxes
=
bbox
,
bounding_boxes
=
bbox
,
min_object_covered
=
min_object_covered
,
min_object_covered
=
min_object_covered
,
aspect_ratio_range
=
aspect_ratio_range
,
aspect_ratio_range
=
aspect_ratio_range
,
...
@@ -188,7 +188,7 @@ def preprocess_for_train(image,
...
@@ -188,7 +188,7 @@ def preprocess_for_train(image,
Returns:
Returns:
3-D float Tensor of distorted image used for training with range [-1, 1].
3-D float Tensor of distorted image used for training with range [-1, 1].
"""
"""
with
tf
.
name_scope
(
scope
,
'distort_image'
,
[
image
,
height
,
width
,
bbox
]):
with
tf
.
compat
.
v1
.
name_scope
(
scope
,
'distort_image'
,
[
image
,
height
,
width
,
bbox
]):
if
bbox
is
None
:
if
bbox
is
None
:
bbox
=
tf
.
constant
(
bbox
=
tf
.
constant
(
[
0.0
,
0.0
,
1.0
,
1.0
],
dtype
=
tf
.
float32
,
shape
=
[
1
,
1
,
4
])
[
0.0
,
0.0
,
1.0
,
1.0
],
dtype
=
tf
.
float32
,
shape
=
[
1
,
1
,
4
])
...
@@ -198,7 +198,7 @@ def preprocess_for_train(image,
...
@@ -198,7 +198,7 @@ def preprocess_for_train(image,
# the coordinates are ordered [ymin, xmin, ymax, xmax].
# the coordinates are ordered [ymin, xmin, ymax, xmax].
image_with_box
=
tf
.
image
.
draw_bounding_boxes
(
image_with_box
=
tf
.
image
.
draw_bounding_boxes
(
tf
.
expand_dims
(
image
,
0
),
bbox
)
tf
.
expand_dims
(
image
,
0
),
bbox
)
tf
.
summary
.
image
(
'image_with_bounding_boxes'
,
image_with_box
)
tf
.
compat
.
v1
.
summary
.
image
(
'image_with_bounding_boxes'
,
image_with_box
)
distorted_image
,
distorted_bbox
=
distorted_bounding_box_crop
(
image
,
bbox
)
distorted_image
,
distorted_bbox
=
distorted_bounding_box_crop
(
image
,
bbox
)
# Restore the shape since the dynamic slice based upon the bbox_size loses
# Restore the shape since the dynamic slice based upon the bbox_size loses
...
@@ -206,8 +206,8 @@ def preprocess_for_train(image,
...
@@ -206,8 +206,8 @@ def preprocess_for_train(image,
distorted_image
.
set_shape
([
None
,
None
,
3
])
distorted_image
.
set_shape
([
None
,
None
,
3
])
image_with_distorted_box
=
tf
.
image
.
draw_bounding_boxes
(
image_with_distorted_box
=
tf
.
image
.
draw_bounding_boxes
(
tf
.
expand_dims
(
image
,
0
),
distorted_bbox
)
tf
.
expand_dims
(
image
,
0
),
distorted_bbox
)
tf
.
summary
.
image
(
'images_with_distorted_bounding_box'
,
tf
.
compat
.
v1
.
summary
.
image
(
'images_with_distorted_bounding_box'
,
image_with_distorted_box
)
image_with_distorted_box
)
# This resizing operation may distort the images because the aspect
# This resizing operation may distort the images because the aspect
# ratio is not respected. We select a resize method in a round robin
# ratio is not respected. We select a resize method in a round robin
...
@@ -218,11 +218,11 @@ def preprocess_for_train(image,
...
@@ -218,11 +218,11 @@ def preprocess_for_train(image,
num_resize_cases
=
1
if
fast_mode
else
4
num_resize_cases
=
1
if
fast_mode
else
4
distorted_image
=
apply_with_random_selector
(
distorted_image
=
apply_with_random_selector
(
distorted_image
,
distorted_image
,
lambda
x
,
method
:
tf
.
image
.
resize
_images
(
x
,
[
height
,
width
],
method
=
method
),
lambda
x
,
method
:
tf
.
image
.
resize
(
x
,
[
height
,
width
],
method
=
method
),
num_cases
=
num_resize_cases
)
num_cases
=
num_resize_cases
)
tf
.
summary
.
image
(
'cropped_resized_image'
,
tf
.
compat
.
v1
.
summary
.
image
(
'cropped_resized_image'
,
tf
.
expand_dims
(
distorted_image
,
0
))
tf
.
expand_dims
(
distorted_image
,
0
))
# Randomly flip the image horizontally.
# Randomly flip the image horizontally.
distorted_image
=
tf
.
image
.
random_flip_left_right
(
distorted_image
)
distorted_image
=
tf
.
image
.
random_flip_left_right
(
distorted_image
)
...
@@ -233,8 +233,8 @@ def preprocess_for_train(image,
...
@@ -233,8 +233,8 @@ def preprocess_for_train(image,
lambda
x
,
ordering
:
distort_color
(
x
,
ordering
,
fast_mode
),
lambda
x
,
ordering
:
distort_color
(
x
,
ordering
,
fast_mode
),
num_cases
=
4
)
num_cases
=
4
)
tf
.
summary
.
image
(
'final_distorted_image'
,
tf
.
compat
.
v1
.
summary
.
image
(
'final_distorted_image'
,
tf
.
expand_dims
(
distorted_image
,
0
))
tf
.
expand_dims
(
distorted_image
,
0
))
distorted_image
=
tf
.
subtract
(
distorted_image
,
0.5
)
distorted_image
=
tf
.
subtract
(
distorted_image
,
0.5
)
distorted_image
=
tf
.
multiply
(
distorted_image
,
2.0
)
distorted_image
=
tf
.
multiply
(
distorted_image
,
2.0
)
return
distorted_image
return
distorted_image
...
@@ -265,7 +265,7 @@ def preprocess_for_eval(image,
...
@@ -265,7 +265,7 @@ def preprocess_for_eval(image,
Returns:
Returns:
3-D float Tensor of prepared image.
3-D float Tensor of prepared image.
"""
"""
with
tf
.
name_scope
(
scope
,
'eval_image'
,
[
image
,
height
,
width
]):
with
tf
.
compat
.
v1
.
name_scope
(
scope
,
'eval_image'
,
[
image
,
height
,
width
]):
if
image
.
dtype
!=
tf
.
float32
:
if
image
.
dtype
!=
tf
.
float32
:
image
=
tf
.
image
.
convert_image_dtype
(
image
,
dtype
=
tf
.
float32
)
image
=
tf
.
image
.
convert_image_dtype
(
image
,
dtype
=
tf
.
float32
)
# Crop the central region of the image with an area containing 87.5% of
# Crop the central region of the image with an area containing 87.5% of
...
@@ -276,8 +276,8 @@ def preprocess_for_eval(image,
...
@@ -276,8 +276,8 @@ def preprocess_for_eval(image,
if
height
and
width
:
if
height
and
width
:
# Resize the image to the specified height and width.
# Resize the image to the specified height and width.
image
=
tf
.
expand_dims
(
image
,
0
)
image
=
tf
.
expand_dims
(
image
,
0
)
image
=
tf
.
image
.
resize
_bilinear
(
image
=
tf
.
image
.
resize
(
image
,
[
height
,
width
],
align_corners
=
False
)
image
,
[
height
,
width
],
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
=
tf
.
squeeze
(
image
,
[
0
])
image
=
tf
.
squeeze
(
image
,
[
0
])
image
=
tf
.
subtract
(
image
,
0.5
)
image
=
tf
.
subtract
(
image
,
0.5
)
image
=
tf
.
multiply
(
image
,
2.0
)
image
=
tf
.
multiply
(
image
,
2.0
)
...
...
research/attention_ocr/python/metrics.py
View file @
0cceabfc
...
@@ -34,20 +34,21 @@ def char_accuracy(predictions, targets, rej_char, streaming=False):
...
@@ -34,20 +34,21 @@ def char_accuracy(predictions, targets, rej_char, streaming=False):
a update_ops for execution and value tensor whose value on evaluation
a update_ops for execution and value tensor whose value on evaluation
returns the total character accuracy.
returns the total character accuracy.
"""
"""
with
tf
.
variable_scope
(
'CharAccuracy'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'CharAccuracy'
):
predictions
.
get_shape
().
assert_is_compatible_with
(
targets
.
get_shape
())
predictions
.
get_shape
().
assert_is_compatible_with
(
targets
.
get_shape
())
targets
=
tf
.
to_int32
(
targets
)
targets
=
tf
.
cast
(
targets
,
dtype
=
tf
.
int32
)
const_rej_char
=
tf
.
constant
(
rej_char
,
shape
=
targets
.
get_shape
())
const_rej_char
=
tf
.
constant
(
rej_char
,
shape
=
targets
.
get_shape
())
weights
=
tf
.
to_float
(
tf
.
not_equal
(
targets
,
const_rej_char
))
weights
=
tf
.
cast
(
tf
.
not_equal
(
targets
,
const_rej_char
),
dtype
=
tf
.
float32
)
correct_chars
=
tf
.
to_float
(
tf
.
equal
(
predictions
,
targets
))
correct_chars
=
tf
.
cast
(
tf
.
equal
(
predictions
,
targets
),
dtype
=
tf
.
float32
)
accuracy_per_example
=
tf
.
div
(
accuracy_per_example
=
tf
.
compat
.
v1
.
div
(
tf
.
reduce_sum
(
tf
.
multiply
(
correct_chars
,
weights
),
1
),
tf
.
reduce_sum
(
input_tensor
=
tf
.
multiply
(
tf
.
reduce_sum
(
weights
,
1
))
correct_chars
,
weights
),
axis
=
1
),
tf
.
reduce_sum
(
input_tensor
=
weights
,
axis
=
1
))
if
streaming
:
if
streaming
:
return
tf
.
contrib
.
metrics
.
streaming_mean
(
accuracy_per_example
)
return
tf
.
contrib
.
metrics
.
streaming_mean
(
accuracy_per_example
)
else
:
else
:
return
tf
.
reduce_mean
(
accuracy_per_example
)
return
tf
.
reduce_mean
(
input_tensor
=
accuracy_per_example
)
def
sequence_accuracy
(
predictions
,
targets
,
rej_char
,
streaming
=
False
):
def
sequence_accuracy
(
predictions
,
targets
,
rej_char
,
streaming
=
False
):
...
@@ -66,25 +67,26 @@ def sequence_accuracy(predictions, targets, rej_char, streaming=False):
...
@@ -66,25 +67,26 @@ def sequence_accuracy(predictions, targets, rej_char, streaming=False):
returns the total sequence accuracy.
returns the total sequence accuracy.
"""
"""
with
tf
.
variable_scope
(
'SequenceAccuracy'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'SequenceAccuracy'
):
predictions
.
get_shape
().
assert_is_compatible_with
(
targets
.
get_shape
())
predictions
.
get_shape
().
assert_is_compatible_with
(
targets
.
get_shape
())
targets
=
tf
.
to_int32
(
targets
)
targets
=
tf
.
cast
(
targets
,
dtype
=
tf
.
int32
)
const_rej_char
=
tf
.
constant
(
const_rej_char
=
tf
.
constant
(
rej_char
,
shape
=
targets
.
get_shape
(),
dtype
=
tf
.
int32
)
rej_char
,
shape
=
targets
.
get_shape
(),
dtype
=
tf
.
int32
)
include_mask
=
tf
.
not_equal
(
targets
,
const_rej_char
)
include_mask
=
tf
.
not_equal
(
targets
,
const_rej_char
)
include_predictions
=
tf
.
to_int32
(
include_predictions
=
tf
.
cast
(
tf
.
where
(
include_mask
,
predictions
,
tf
.
compat
.
v1
.
where
(
include_mask
,
predictions
,
tf
.
zeros_like
(
predictions
)
+
rej_char
))
tf
.
zeros_like
(
predictions
)
+
rej_char
),
dtype
=
tf
.
int32
)
correct_chars
=
tf
.
to_float
(
tf
.
equal
(
include_predictions
,
targets
))
correct_chars
=
tf
.
cast
(
tf
.
equal
(
include_predictions
,
targets
),
dtype
=
tf
.
float32
)
correct_chars_counts
=
tf
.
cast
(
correct_chars_counts
=
tf
.
cast
(
tf
.
reduce_sum
(
correct_chars
,
reduction_indice
s
=
[
1
]),
dtype
=
tf
.
int32
)
tf
.
reduce_sum
(
input_tensor
=
correct_chars
,
axi
s
=
[
1
]),
dtype
=
tf
.
int32
)
target_length
=
targets
.
get_shape
().
dims
[
1
].
value
target_length
=
targets
.
get_shape
().
dims
[
1
].
value
target_chars_counts
=
tf
.
constant
(
target_chars_counts
=
tf
.
constant
(
target_length
,
shape
=
correct_chars_counts
.
get_shape
())
target_length
,
shape
=
correct_chars_counts
.
get_shape
())
accuracy_per_example
=
tf
.
to_floa
t
(
accuracy_per_example
=
tf
.
cas
t
(
tf
.
equal
(
correct_chars_counts
,
target_chars_counts
))
tf
.
equal
(
correct_chars_counts
,
target_chars_counts
)
,
dtype
=
tf
.
float32
)
if
streaming
:
if
streaming
:
return
tf
.
contrib
.
metrics
.
streaming_mean
(
accuracy_per_example
)
return
tf
.
contrib
.
metrics
.
streaming_mean
(
accuracy_per_example
)
else
:
else
:
return
tf
.
reduce_mean
(
accuracy_per_example
)
return
tf
.
reduce_mean
(
input_tensor
=
accuracy_per_example
)
research/attention_ocr/python/metrics_test.py
View file @
0cceabfc
...
@@ -38,8 +38,8 @@ class AccuracyTest(tf.test.TestCase):
...
@@ -38,8 +38,8 @@ class AccuracyTest(tf.test.TestCase):
A session object that should be used as a context manager.
A session object that should be used as a context manager.
"""
"""
with
self
.
cached_session
()
as
sess
:
with
self
.
cached_session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
sess
.
run
(
tf
.
local_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
local_variables_initializer
())
yield
sess
yield
sess
def
_fake_labels
(
self
):
def
_fake_labels
(
self
):
...
@@ -55,7 +55,7 @@ class AccuracyTest(tf.test.TestCase):
...
@@ -55,7 +55,7 @@ class AccuracyTest(tf.test.TestCase):
return
incorrect
return
incorrect
def
test_sequence_accuracy_identical_samples
(
self
):
def
test_sequence_accuracy_identical_samples
(
self
):
labels_tf
=
tf
.
convert_to_tensor
(
self
.
_fake_labels
())
labels_tf
=
tf
.
convert_to_tensor
(
value
=
self
.
_fake_labels
())
accuracy_tf
=
metrics
.
sequence_accuracy
(
labels_tf
,
labels_tf
,
accuracy_tf
=
metrics
.
sequence_accuracy
(
labels_tf
,
labels_tf
,
self
.
rej_char
)
self
.
rej_char
)
...
@@ -66,9 +66,9 @@ class AccuracyTest(tf.test.TestCase):
...
@@ -66,9 +66,9 @@ class AccuracyTest(tf.test.TestCase):
def
test_sequence_accuracy_one_char_difference
(
self
):
def
test_sequence_accuracy_one_char_difference
(
self
):
ground_truth_np
=
self
.
_fake_labels
()
ground_truth_np
=
self
.
_fake_labels
()
ground_truth_tf
=
tf
.
convert_to_tensor
(
ground_truth_np
)
ground_truth_tf
=
tf
.
convert_to_tensor
(
value
=
ground_truth_np
)
prediction_tf
=
tf
.
convert_to_tensor
(
prediction_tf
=
tf
.
convert_to_tensor
(
self
.
_incorrect_copy
(
ground_truth_np
,
bad_indexes
=
((
0
,
0
))))
value
=
self
.
_incorrect_copy
(
ground_truth_np
,
bad_indexes
=
((
0
,
0
))))
accuracy_tf
=
metrics
.
sequence_accuracy
(
prediction_tf
,
ground_truth_tf
,
accuracy_tf
=
metrics
.
sequence_accuracy
(
prediction_tf
,
ground_truth_tf
,
self
.
rej_char
)
self
.
rej_char
)
...
@@ -80,9 +80,9 @@ class AccuracyTest(tf.test.TestCase):
...
@@ -80,9 +80,9 @@ class AccuracyTest(tf.test.TestCase):
def
test_char_accuracy_one_char_difference_with_padding
(
self
):
def
test_char_accuracy_one_char_difference_with_padding
(
self
):
ground_truth_np
=
self
.
_fake_labels
()
ground_truth_np
=
self
.
_fake_labels
()
ground_truth_tf
=
tf
.
convert_to_tensor
(
ground_truth_np
)
ground_truth_tf
=
tf
.
convert_to_tensor
(
value
=
ground_truth_np
)
prediction_tf
=
tf
.
convert_to_tensor
(
prediction_tf
=
tf
.
convert_to_tensor
(
self
.
_incorrect_copy
(
ground_truth_np
,
bad_indexes
=
((
0
,
0
))))
value
=
self
.
_incorrect_copy
(
ground_truth_np
,
bad_indexes
=
((
0
,
0
))))
accuracy_tf
=
metrics
.
char_accuracy
(
prediction_tf
,
ground_truth_tf
,
accuracy_tf
=
metrics
.
char_accuracy
(
prediction_tf
,
ground_truth_tf
,
self
.
rej_char
)
self
.
rej_char
)
...
...
research/attention_ocr/python/model.py
View file @
0cceabfc
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Functions to build the Attention OCR model.
"""Functions to build the Attention OCR model.
Usage example:
Usage example:
...
@@ -26,6 +25,7 @@ Usage example:
...
@@ -26,6 +25,7 @@ Usage example:
import
sys
import
sys
import
collections
import
collections
import
logging
import
logging
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.contrib
import
slim
from
tensorflow.contrib
import
slim
from
tensorflow.contrib.slim.nets
import
inception
from
tensorflow.contrib.slim.nets
import
inception
...
@@ -35,29 +35,28 @@ import sequence_layers
...
@@ -35,29 +35,28 @@ import sequence_layers
import
utils
import
utils
OutputEndpoints
=
collections
.
namedtuple
(
'OutputEndpoints'
,
[
OutputEndpoints
=
collections
.
namedtuple
(
'OutputEndpoints'
,
[
'chars_logit'
,
'chars_log_prob'
,
'predicted_chars'
,
'predicted_scores'
,
'chars_logit'
,
'chars_log_prob'
,
'predicted_chars'
,
'predicted_scores'
,
'predicted_text'
'predicted_text'
,
'predicted_length'
,
'predicted_conf'
,
'normalized_seq_conf'
])
])
# TODO(gorban): replace with tf.HParams when it is released.
# TODO(gorban): replace with tf.HParams when it is released.
ModelParams
=
collections
.
namedtuple
(
'ModelParams'
,
[
ModelParams
=
collections
.
namedtuple
(
'num_char_classes'
,
'seq_length'
,
'num_views'
,
'null_code'
'ModelParams'
,
[
'num_char_classes'
,
'seq_length'
,
'num_views'
,
'null_code'
])
])
ConvTowerParams
=
collections
.
namedtuple
(
'ConvTowerParams'
,
[
'final_endpoint'
])
ConvTowerParams
=
collections
.
namedtuple
(
'ConvTowerParams'
,
[
'final_endpoint'
])
SequenceLogitsParams
=
collections
.
namedtuple
(
'SequenceLogitsParams'
,
[
SequenceLogitsParams
=
collections
.
namedtuple
(
'SequenceLogitsParams'
,
[
'use_attention'
,
'use_autoregression'
,
'num_lstm_units'
,
'weight_decay'
,
'use_attention'
,
'use_autoregression'
,
'num_lstm_units'
,
'weight_decay'
,
'lstm_state_clip_value'
'lstm_state_clip_value'
])
])
SequenceLossParams
=
collections
.
namedtuple
(
'SequenceLossParams'
,
[
SequenceLossParams
=
collections
.
namedtuple
(
'label_smoothing'
,
'ignore_nulls'
,
'average_across_timestep
s'
'SequenceLossParam
s'
,
])
[
'label_smoothing'
,
'ignore_nulls'
,
'average_across_timesteps'
])
EncodeCoordinatesParams
=
collections
.
namedtuple
(
'EncodeCoordinatesParams'
,
[
EncodeCoordinatesParams
=
collections
.
namedtuple
(
'EncodeCoordinatesParams'
,
'enabled'
[
'enabled'
])
])
def
_dict_to_array
(
id_to_char
,
default_character
):
def
_dict_to_array
(
id_to_char
,
default_character
):
...
@@ -85,16 +84,16 @@ class CharsetMapper(object):
...
@@ -85,16 +84,16 @@ class CharsetMapper(object):
"""
"""
mapping_strings
=
tf
.
constant
(
_dict_to_array
(
charset
,
default_character
))
mapping_strings
=
tf
.
constant
(
_dict_to_array
(
charset
,
default_character
))
self
.
table
=
tf
.
contrib
.
lookup
.
index_to_string_table_from_tensor
(
self
.
table
=
tf
.
contrib
.
lookup
.
index_to_string_table_from_tensor
(
mapping
=
mapping_strings
,
default_value
=
default_character
)
mapping
=
mapping_strings
,
default_value
=
default_character
)
def
get_text
(
self
,
ids
):
def
get_text
(
self
,
ids
):
"""Returns a string corresponding to a sequence of character ids.
"""Returns a string corresponding to a sequence of character ids.
Args:
Args:
ids: a tensor with shape [batch_size, max_sequence_length]
ids: a tensor with shape [batch_size, max_sequence_length]
"""
"""
return
tf
.
reduce_join
(
return
tf
.
strings
.
reduce_join
(
self
.
table
.
lookup
(
tf
.
to_int64
(
ids
)),
reduction_indice
s
=
1
)
inputs
=
self
.
table
.
lookup
(
tf
.
cast
(
ids
,
dtype
=
tf
.
int64
)),
axi
s
=
1
)
def
get_softmax_loss_fn
(
label_smoothing
):
def
get_softmax_loss_fn
(
label_smoothing
):
...
@@ -111,16 +110,153 @@ def get_softmax_loss_fn(label_smoothing):
...
@@ -111,16 +110,153 @@ def get_softmax_loss_fn(label_smoothing):
def
loss_fn
(
labels
,
logits
):
def
loss_fn
(
labels
,
logits
):
return
(
tf
.
nn
.
softmax_cross_entropy_with_logits
(
return
(
tf
.
nn
.
softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
labels
))
logits
=
logits
,
labels
=
tf
.
stop_gradient
(
labels
))
)
else
:
else
:
def
loss_fn
(
labels
,
logits
):
def
loss_fn
(
labels
,
logits
):
return
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
return
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
labels
)
logits
=
logits
,
labels
=
labels
)
return
loss_fn
return
loss_fn
def
get_tensor_dimensions
(
tensor
):
"""Returns the shape components of a 4D tensor with variable batch size.
Args:
tensor : A 4D tensor, whose last 3 dimensions are known at graph
construction time.
Returns:
batch_size : The first dimension as a tensor object.
height : The second dimension as a scalar value.
width : The third dimension as a scalar value.
num_features : The forth dimension as a scalar value.
Raises:
ValueError: if input tensor does not have 4 dimensions.
"""
if
len
(
tensor
.
get_shape
().
dims
)
!=
4
:
raise
ValueError
(
'Incompatible shape: len(tensor.get_shape().dims) != 4 (%d != 4)'
%
len
(
tensor
.
get_shape
().
dims
))
batch_size
=
tf
.
shape
(
input
=
tensor
)[
0
]
height
=
tensor
.
get_shape
().
dims
[
1
].
value
width
=
tensor
.
get_shape
().
dims
[
2
].
value
num_features
=
tensor
.
get_shape
().
dims
[
3
].
value
return
batch_size
,
height
,
width
,
num_features
def
lookup_indexed_value
(
indices
,
row_vecs
):
"""Lookup values in each row of 'row_vecs' indexed by 'indices'.
For each sample in the batch, look up the element for the corresponding
index.
Args:
indices : A tensor of shape (batch, )
row_vecs : A tensor of shape [batch, depth]
Returns:
A tensor of shape (batch, ) formed by row_vecs[i, indices[i]].
"""
gather_indices
=
tf
.
stack
((
tf
.
range
(
tf
.
shape
(
input
=
row_vecs
)[
0
],
dtype
=
tf
.
int32
),
tf
.
cast
(
indices
,
tf
.
int32
)),
axis
=
1
)
return
tf
.
gather_nd
(
row_vecs
,
gather_indices
)
@
utils
.
ConvertAllInputsToTensors
def
max_char_logprob_cumsum
(
char_log_prob
):
"""Computes the cumulative sum of character logprob for all sequence lengths.
Args:
char_log_prob: A tensor of shape [batch x seq_length x num_char_classes]
with log probabilities of a character.
Returns:
A tensor of shape [batch x (seq_length+1)] where each element x[_, j] is
the sum of the max char logprob for all positions upto j.
Note this duplicates the final column and produces (seq_length+1) columns
so the same function can be used regardless whether use_length_predictions
is true or false.
"""
max_char_log_prob
=
tf
.
reduce_max
(
input_tensor
=
char_log_prob
,
axis
=
2
)
# For an input array [a, b, c]) tf.cumsum returns [a, a + b, a + b + c] if
# exclusive set to False (default).
return
tf
.
cumsum
(
max_char_log_prob
,
axis
=
1
,
exclusive
=
False
)
def
find_length_by_null
(
predicted_chars
,
null_code
):
"""Determine sequence length by finding null_code among predicted char IDs.
Given the char class ID for each position, compute the sequence length.
Note that this function computes this based on the number of null_code,
instead of the position of the first null_code.
Args:
predicted_chars: A tensor of [batch x seq_length] where each element stores
the char class ID with max probability;
null_code: an int32, character id for the NULL.
Returns:
A [batch, ] tensor which stores the sequence length for each sample.
"""
return
tf
.
reduce_sum
(
input_tensor
=
tf
.
cast
(
tf
.
not_equal
(
null_code
,
predicted_chars
),
tf
.
int32
),
axis
=
1
)
def
axis_pad
(
tensor
,
axis
,
before
=
0
,
after
=
0
,
constant_values
=
0.0
):
"""Pad a tensor with the specified values along a single axis.
Args:
tensor: a Tensor;
axis: the dimension to add pad along to;
before: number of values to add before the contents of tensor in the
selected dimension;
after: number of values to add after the contents of tensor in the selected
dimension;
constant_values: the scalar pad value to use. Must be same type as tensor.
Returns:
A Tensor. Has the same type as the input tensor, but with a changed shape
along the specified dimension.
"""
if
before
==
0
and
after
==
0
:
return
tensor
ndims
=
tensor
.
shape
.
ndims
padding_size
=
np
.
zeros
((
ndims
,
2
),
dtype
=
'int32'
)
padding_size
[
axis
]
=
before
,
after
return
tf
.
pad
(
tensor
=
tensor
,
paddings
=
tf
.
constant
(
padding_size
),
constant_values
=
constant_values
)
def
null_based_length_prediction
(
chars_log_prob
,
null_code
):
"""Computes length and confidence of prediction based on positions of NULLs.
Args:
chars_log_prob: A tensor of shape [batch x seq_length x num_char_classes]
with log probabilities of a character;
null_code: an int32, character id for the NULL.
Returns:
A tuple (text_log_prob, predicted_length), where
text_log_prob - is a tensor of the same shape as length_log_prob.
Element #0 of the output corresponds to probability of the empty string,
element #seq_length - is the probability of length=seq_length.
predicted_length is a tensor with shape [batch].
"""
predicted_chars
=
tf
.
cast
(
tf
.
argmax
(
input
=
chars_log_prob
,
axis
=
2
),
dtype
=
tf
.
int32
)
# We do right pad to support sequences with seq_length elements.
text_log_prob
=
max_char_logprob_cumsum
(
axis_pad
(
chars_log_prob
,
axis
=
1
,
after
=
1
))
predicted_length
=
find_length_by_null
(
predicted_chars
,
null_code
)
return
text_log_prob
,
predicted_length
class
Model
(
object
):
class
Model
(
object
):
"""Class to create the Attention OCR Model."""
"""Class to create the Attention OCR Model."""
...
@@ -137,24 +273,24 @@ class Model(object):
...
@@ -137,24 +273,24 @@ class Model(object):
num_char_classes: size of character set.
num_char_classes: size of character set.
seq_length: number of characters in a sequence.
seq_length: number of characters in a sequence.
num_views: Number of views (conv towers) to use.
num_views: Number of views (conv towers) to use.
null_code: A character code corresponding to a character which
null_code: A character code corresponding to a character which
indicates
indicates
end of a sequence.
end of a sequence.
mparams: a dictionary with hyper parameters for methods, keys -
mparams: a dictionary with hyper parameters for methods, keys -
function
function
names, values - corresponding namedtuples.
names, values - corresponding namedtuples.
charset: an optional dictionary with a mapping between character ids and
charset: an optional dictionary with a mapping between character ids and
utf8 strings. If specified the OutputEndpoints.predicted_text will
utf8 strings. If specified the OutputEndpoints.predicted_text will
utf8
utf8
encoded strings corresponding to the character ids returned by
encoded strings corresponding to the character ids returned by
OutputEndpoints.predicted_chars (by default the predicted_text contains
OutputEndpoints.predicted_chars (by default the predicted_text contains
an empty vector).
an empty vector).
NOTE: Make sure you call tf.tables_initializer().run() if the charset
NOTE: Make sure you call tf.tables_initializer().run() if the charset
specified.
specified.
"""
"""
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
self
.
_params
=
ModelParams
(
self
.
_params
=
ModelParams
(
num_char_classes
=
num_char_classes
,
num_char_classes
=
num_char_classes
,
seq_length
=
seq_length
,
seq_length
=
seq_length
,
num_views
=
num_views
,
num_views
=
num_views
,
null_code
=
null_code
)
null_code
=
null_code
)
self
.
_mparams
=
self
.
default_mparams
()
self
.
_mparams
=
self
.
default_mparams
()
if
mparams
:
if
mparams
:
self
.
_mparams
.
update
(
mparams
)
self
.
_mparams
.
update
(
mparams
)
...
@@ -162,21 +298,22 @@ class Model(object):
...
@@ -162,21 +298,22 @@ class Model(object):
def
default_mparams
(
self
):
def
default_mparams
(
self
):
return
{
return
{
'conv_tower_fn'
:
'conv_tower_fn'
:
ConvTowerParams
(
final_endpoint
=
'Mixed_5d'
),
ConvTowerParams
(
final_endpoint
=
'Mixed_5d'
),
'sequence_logit_fn'
:
'sequence_logit_fn'
:
SequenceLogitsParams
(
SequenceLogitsParams
(
use_attention
=
True
,
use_attention
=
True
,
use_autoregression
=
True
,
use_autoregression
=
True
,
num_lstm_units
=
256
,
num_lstm_units
=
256
,
weight_decay
=
0.00004
,
weight_decay
=
0.00004
,
lstm_state_clip_value
=
10.0
),
lstm_state_clip_value
=
10.0
),
'sequence_loss_fn'
:
'sequence_loss_fn'
:
SequenceLossParams
(
SequenceLossParams
(
label_smoothing
=
0.1
,
label_smoothing
=
0.1
,
ignore_nulls
=
True
,
ignore_nulls
=
True
,
average_across_timesteps
=
False
),
average_across_timesteps
=
False
),
'encode_coordinates_fn'
:
EncodeCoordinatesParams
(
enabled
=
False
)
'encode_coordinates_fn'
:
EncodeCoordinatesParams
(
enabled
=
False
)
}
}
def
set_mparam
(
self
,
function
,
**
kwargs
):
def
set_mparam
(
self
,
function
,
**
kwargs
):
...
@@ -198,14 +335,14 @@ class Model(object):
...
@@ -198,14 +335,14 @@ class Model(object):
"""
"""
mparams
=
self
.
_mparams
[
'conv_tower_fn'
]
mparams
=
self
.
_mparams
[
'conv_tower_fn'
]
logging
.
debug
(
'Using final_endpoint=%s'
,
mparams
.
final_endpoint
)
logging
.
debug
(
'Using final_endpoint=%s'
,
mparams
.
final_endpoint
)
with
tf
.
variable_scope
(
'conv_tower_fn/INCE'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'conv_tower_fn/INCE'
):
if
reuse
:
if
reuse
:
tf
.
get_variable_scope
().
reuse_variables
()
tf
.
compat
.
v1
.
get_variable_scope
().
reuse_variables
()
with
slim
.
arg_scope
(
inception
.
inception_v3_arg_scope
()):
with
slim
.
arg_scope
(
inception
.
inception_v3_arg_scope
()):
with
slim
.
arg_scope
([
slim
.
batch_norm
,
slim
.
dropout
],
with
slim
.
arg_scope
([
slim
.
batch_norm
,
slim
.
dropout
],
is_training
=
is_training
):
is_training
=
is_training
):
net
,
_
=
inception
.
inception_v3_base
(
net
,
_
=
inception
.
inception_v3_base
(
images
,
final_endpoint
=
mparams
.
final_endpoint
)
images
,
final_endpoint
=
mparams
.
final_endpoint
)
return
net
return
net
def
_create_lstm_inputs
(
self
,
net
):
def
_create_lstm_inputs
(
self
,
net
):
...
@@ -222,10 +359,10 @@ class Model(object):
...
@@ -222,10 +359,10 @@ class Model(object):
"""
"""
num_features
=
net
.
get_shape
().
dims
[
1
].
value
num_features
=
net
.
get_shape
().
dims
[
1
].
value
if
num_features
<
self
.
_params
.
seq_length
:
if
num_features
<
self
.
_params
.
seq_length
:
raise
AssertionError
(
'Incorrect dimension #1 of input tensor'
raise
AssertionError
(
' %d should be bigger than %d (shape=%s)'
%
'Incorrect dimension #1 of input tensor'
(
num_features
,
self
.
_params
.
seq_length
,
' %d should be bigger than %d (shape=%s)'
%
net
.
get_shape
()))
(
num_features
,
self
.
_params
.
seq_length
,
net
.
get_shape
()))
elif
num_features
>
self
.
_params
.
seq_length
:
elif
num_features
>
self
.
_params
.
seq_length
:
logging
.
warning
(
'Ignoring some features: use %d of %d (shape=%s)'
,
logging
.
warning
(
'Ignoring some features: use %d of %d (shape=%s)'
,
self
.
_params
.
seq_length
,
num_features
,
net
.
get_shape
())
self
.
_params
.
seq_length
,
num_features
,
net
.
get_shape
())
...
@@ -236,7 +373,7 @@ class Model(object):
...
@@ -236,7 +373,7 @@ class Model(object):
def
sequence_logit_fn
(
self
,
net
,
labels_one_hot
):
def
sequence_logit_fn
(
self
,
net
,
labels_one_hot
):
mparams
=
self
.
_mparams
[
'sequence_logit_fn'
]
mparams
=
self
.
_mparams
[
'sequence_logit_fn'
]
# TODO(gorban): remove /alias suffixes from the scopes.
# TODO(gorban): remove /alias suffixes from the scopes.
with
tf
.
variable_scope
(
'sequence_logit_fn/SQLR'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'sequence_logit_fn/SQLR'
):
layer_class
=
sequence_layers
.
get_layer_class
(
mparams
.
use_attention
,
layer_class
=
sequence_layers
.
get_layer_class
(
mparams
.
use_attention
,
mparams
.
use_autoregression
)
mparams
.
use_autoregression
)
layer
=
layer_class
(
net
,
labels_one_hot
,
self
.
_params
,
mparams
)
layer
=
layer_class
(
net
,
labels_one_hot
,
self
.
_params
,
mparams
)
...
@@ -252,16 +389,16 @@ class Model(object):
...
@@ -252,16 +389,16 @@ class Model(object):
A tensor with the same size as any input tensors.
A tensor with the same size as any input tensors.
"""
"""
batch_size
,
height
,
width
,
num_features
=
[
batch_size
,
height
,
width
,
num_features
=
[
d
.
value
for
d
in
nets_list
[
0
].
get_shape
().
dims
d
.
value
for
d
in
nets_list
[
0
].
get_shape
().
dims
]
]
xy_flat_shape
=
(
batch_size
,
1
,
height
*
width
,
num_features
)
xy_flat_shape
=
(
batch_size
,
1
,
height
*
width
,
num_features
)
nets_for_merge
=
[]
nets_for_merge
=
[]
with
tf
.
variable_scope
(
'max_pool_views'
,
values
=
nets_list
):
with
tf
.
compat
.
v1
.
variable_scope
(
'max_pool_views'
,
values
=
nets_list
):
for
net
in
nets_list
:
for
net
in
nets_list
:
nets_for_merge
.
append
(
tf
.
reshape
(
net
,
xy_flat_shape
))
nets_for_merge
.
append
(
tf
.
reshape
(
net
,
xy_flat_shape
))
merged_net
=
tf
.
concat
(
nets_for_merge
,
1
)
merged_net
=
tf
.
concat
(
nets_for_merge
,
1
)
net
=
slim
.
max_pool2d
(
net
=
slim
.
max_pool2d
(
merged_net
,
kernel_size
=
[
len
(
nets_list
),
1
],
stride
=
1
)
merged_net
,
kernel_size
=
[
len
(
nets_list
),
1
],
stride
=
1
)
net
=
tf
.
reshape
(
net
,
(
batch_size
,
height
,
width
,
num_features
))
net
=
tf
.
reshape
(
net
,
(
batch_size
,
height
,
width
,
num_features
))
return
net
return
net
...
@@ -277,18 +414,20 @@ class Model(object):
...
@@ -277,18 +414,20 @@ class Model(object):
Returns:
Returns:
A tensor of shape [batch_size, seq_length, features_size].
A tensor of shape [batch_size, seq_length, features_size].
"""
"""
with
tf
.
variable_scope
(
'pool_views_fn/STCK'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'pool_views_fn/STCK'
):
net
=
tf
.
concat
(
nets
,
1
)
net
=
tf
.
concat
(
nets
,
1
)
batch_size
=
net
.
get_shape
().
dims
[
0
].
value
batch_size
=
tf
.
shape
(
input
=
net
)[
0
]
image_size
=
net
.
get_shape
().
dims
[
1
].
value
*
\
net
.
get_shape
().
dims
[
2
].
value
feature_size
=
net
.
get_shape
().
dims
[
3
].
value
feature_size
=
net
.
get_shape
().
dims
[
3
].
value
return
tf
.
reshape
(
net
,
[
batch_size
,
-
1
,
feature_size
])
return
tf
.
reshape
(
net
,
tf
.
stack
(
[
batch_size
,
image_size
,
feature_size
])
)
def
char_predictions
(
self
,
chars_logit
):
def
char_predictions
(
self
,
chars_logit
):
"""Returns confidence scores (softmax values) for predicted characters.
"""Returns confidence scores (softmax values) for predicted characters.
Args:
Args:
chars_logit: chars logits, a tensor with shape
chars_logit: chars logits, a tensor with shape
[batch_size x seq_length x
[batch_size x seq_length x
num_char_classes]
num_char_classes]
Returns:
Returns:
A tuple (ids, log_prob, scores), where:
A tuple (ids, log_prob, scores), where:
...
@@ -301,12 +440,17 @@ class Model(object):
...
@@ -301,12 +440,17 @@ class Model(object):
with shape [batch_size x seq_length].
with shape [batch_size x seq_length].
"""
"""
log_prob
=
utils
.
logits_to_log_prob
(
chars_logit
)
log_prob
=
utils
.
logits_to_log_prob
(
chars_logit
)
ids
=
tf
.
to_int32
(
tf
.
argmax
(
log_prob
,
axis
=
2
),
name
=
'predicted_chars'
)
ids
=
tf
.
cast
(
tf
.
argmax
(
input
=
log_prob
,
axis
=
2
),
name
=
'predicted_chars'
,
dtype
=
tf
.
int32
)
mask
=
tf
.
cast
(
mask
=
tf
.
cast
(
slim
.
one_hot_encoding
(
ids
,
self
.
_params
.
num_char_classes
),
tf
.
bool
)
slim
.
one_hot_encoding
(
ids
,
self
.
_params
.
num_char_classes
),
tf
.
bool
)
all_scores
=
tf
.
nn
.
softmax
(
chars_logit
)
all_scores
=
tf
.
nn
.
softmax
(
chars_logit
)
selected_scores
=
tf
.
boolean_mask
(
all_scores
,
mask
,
name
=
'char_scores'
)
selected_scores
=
tf
.
boolean_mask
(
scores
=
tf
.
reshape
(
selected_scores
,
shape
=
(
-
1
,
self
.
_params
.
seq_length
))
tensor
=
all_scores
,
mask
=
mask
,
name
=
'char_scores'
)
scores
=
tf
.
reshape
(
selected_scores
,
shape
=
(
-
1
,
self
.
_params
.
seq_length
),
name
=
'predicted_scores'
)
return
ids
,
log_prob
,
scores
return
ids
,
log_prob
,
scores
def
encode_coordinates_fn
(
self
,
net
):
def
encode_coordinates_fn
(
self
,
net
):
...
@@ -323,12 +467,12 @@ class Model(object):
...
@@ -323,12 +467,12 @@ class Model(object):
"""
"""
mparams
=
self
.
_mparams
[
'encode_coordinates_fn'
]
mparams
=
self
.
_mparams
[
'encode_coordinates_fn'
]
if
mparams
.
enabled
:
if
mparams
.
enabled
:
batch_size
,
h
,
w
,
_
=
n
et
.
shape
.
as_list
(
)
batch_size
,
h
,
w
,
_
=
g
et
_tensor_dimensions
(
net
)
x
,
y
=
tf
.
meshgrid
(
tf
.
range
(
w
),
tf
.
range
(
h
))
x
,
y
=
tf
.
meshgrid
(
tf
.
range
(
w
),
tf
.
range
(
h
))
w_loc
=
slim
.
one_hot_encoding
(
x
,
num_classes
=
w
)
w_loc
=
slim
.
one_hot_encoding
(
x
,
num_classes
=
w
)
h_loc
=
slim
.
one_hot_encoding
(
y
,
num_classes
=
h
)
h_loc
=
slim
.
one_hot_encoding
(
y
,
num_classes
=
h
)
loc
=
tf
.
concat
([
h_loc
,
w_loc
],
2
)
loc
=
tf
.
concat
([
h_loc
,
w_loc
],
2
)
loc
=
tf
.
tile
(
tf
.
expand_dims
(
loc
,
0
),
[
batch_size
,
1
,
1
,
1
])
loc
=
tf
.
tile
(
tf
.
expand_dims
(
loc
,
0
),
tf
.
stack
(
[
batch_size
,
1
,
1
,
1
])
)
return
tf
.
concat
([
net
,
loc
],
3
)
return
tf
.
concat
([
net
,
loc
],
3
)
else
:
else
:
return
net
return
net
...
@@ -341,7 +485,8 @@ class Model(object):
...
@@ -341,7 +485,8 @@ class Model(object):
"""Creates a base part of the Model (no gradients, losses or summaries).
"""Creates a base part of the Model (no gradients, losses or summaries).
Args:
Args:
images: A tensor of shape [batch_size, height, width, channels].
images: A tensor of shape [batch_size, height, width, channels] with pixel
values in the range [0.0, 1.0].
labels_one_hot: Optional (can be None) one-hot encoding for ground truth
labels_one_hot: Optional (can be None) one-hot encoding for ground truth
labels. If provided the function will create a model for training.
labels. If provided the function will create a model for training.
scope: Optional variable_scope.
scope: Optional variable_scope.
...
@@ -353,14 +498,19 @@ class Model(object):
...
@@ -353,14 +498,19 @@ class Model(object):
"""
"""
logging
.
debug
(
'images: %s'
,
images
)
logging
.
debug
(
'images: %s'
,
images
)
is_training
=
labels_one_hot
is
not
None
is_training
=
labels_one_hot
is
not
None
with
tf
.
variable_scope
(
scope
,
reuse
=
reuse
):
# Normalize image pixel values to have a symmetrical range around zero.
images
=
tf
.
subtract
(
images
,
0.5
)
images
=
tf
.
multiply
(
images
,
2.5
)
with
tf
.
compat
.
v1
.
variable_scope
(
scope
,
reuse
=
reuse
):
views
=
tf
.
split
(
views
=
tf
.
split
(
value
=
images
,
num_or_size_splits
=
self
.
_params
.
num_views
,
axis
=
2
)
value
=
images
,
num_or_size_splits
=
self
.
_params
.
num_views
,
axis
=
2
)
logging
.
debug
(
'Views=%d single view: %s'
,
len
(
views
),
views
[
0
])
logging
.
debug
(
'Views=%d single view: %s'
,
len
(
views
),
views
[
0
])
nets
=
[
nets
=
[
self
.
conv_tower_fn
(
v
,
is_training
,
reuse
=
(
i
!=
0
))
self
.
conv_tower_fn
(
v
,
is_training
,
reuse
=
(
i
!=
0
))
for
i
,
v
in
enumerate
(
views
)
for
i
,
v
in
enumerate
(
views
)
]
]
logging
.
debug
(
'Conv tower: %s'
,
nets
[
0
])
logging
.
debug
(
'Conv tower: %s'
,
nets
[
0
])
...
@@ -374,18 +524,34 @@ class Model(object):
...
@@ -374,18 +524,34 @@ class Model(object):
logging
.
debug
(
'chars_logit: %s'
,
chars_logit
)
logging
.
debug
(
'chars_logit: %s'
,
chars_logit
)
predicted_chars
,
chars_log_prob
,
predicted_scores
=
(
predicted_chars
,
chars_log_prob
,
predicted_scores
=
(
self
.
char_predictions
(
chars_logit
))
self
.
char_predictions
(
chars_logit
))
if
self
.
_charset
:
if
self
.
_charset
:
character_mapper
=
CharsetMapper
(
self
.
_charset
)
character_mapper
=
CharsetMapper
(
self
.
_charset
)
predicted_text
=
character_mapper
.
get_text
(
predicted_chars
)
predicted_text
=
character_mapper
.
get_text
(
predicted_chars
)
else
:
else
:
predicted_text
=
tf
.
constant
([])
predicted_text
=
tf
.
constant
([])
text_log_prob
,
predicted_length
=
null_based_length_prediction
(
chars_log_prob
,
self
.
_params
.
null_code
)
predicted_conf
=
lookup_indexed_value
(
predicted_length
,
text_log_prob
)
# Convert predicted confidence from sum of logs to geometric mean
normalized_seq_conf
=
tf
.
exp
(
tf
.
divide
(
predicted_conf
,
tf
.
cast
(
predicted_length
+
1
,
predicted_conf
.
dtype
)),
name
=
'normalized_seq_conf'
)
predicted_conf
=
tf
.
identity
(
predicted_conf
,
name
=
'predicted_conf'
)
predicted_text
=
tf
.
identity
(
predicted_text
,
name
=
'predicted_text'
)
predicted_length
=
tf
.
identity
(
predicted_length
,
name
=
'predicted_length'
)
return
OutputEndpoints
(
return
OutputEndpoints
(
chars_logit
=
chars_logit
,
chars_logit
=
chars_logit
,
chars_log_prob
=
chars_log_prob
,
chars_log_prob
=
chars_log_prob
,
predicted_chars
=
predicted_chars
,
predicted_chars
=
predicted_chars
,
predicted_scores
=
predicted_scores
,
predicted_scores
=
predicted_scores
,
predicted_text
=
predicted_text
)
predicted_length
=
predicted_length
,
predicted_text
=
predicted_text
,
predicted_conf
=
predicted_conf
,
normalized_seq_conf
=
normalized_seq_conf
)
def
create_loss
(
self
,
data
,
endpoints
):
def
create_loss
(
self
,
data
,
endpoints
):
"""Creates all losses required to train the model.
"""Creates all losses required to train the model.
...
@@ -404,7 +570,7 @@ class Model(object):
...
@@ -404,7 +570,7 @@ class Model(object):
# multiple losses including regularization losses.
# multiple losses including regularization losses.
self
.
sequence_loss_fn
(
endpoints
.
chars_logit
,
data
.
labels
)
self
.
sequence_loss_fn
(
endpoints
.
chars_logit
,
data
.
labels
)
total_loss
=
slim
.
losses
.
get_total_loss
()
total_loss
=
slim
.
losses
.
get_total_loss
()
tf
.
summary
.
scalar
(
'TotalLoss'
,
total_loss
)
tf
.
compat
.
v1
.
summary
.
scalar
(
'TotalLoss'
,
total_loss
)
return
total_loss
return
total_loss
def
label_smoothing_regularization
(
self
,
chars_labels
,
weight
=
0.1
):
def
label_smoothing_regularization
(
self
,
chars_labels
,
weight
=
0.1
):
...
@@ -413,15 +579,15 @@ class Model(object):
...
@@ -413,15 +579,15 @@ class Model(object):
Uses the same method as in https://arxiv.org/abs/1512.00567.
Uses the same method as in https://arxiv.org/abs/1512.00567.
Args:
Args:
chars_labels: ground truth ids of charactes,
chars_labels: ground truth ids of charactes,
shape=[batch_size,
shape=[batch_size,
seq_length];
seq_length];
weight: label-smoothing regularization weight.
weight: label-smoothing regularization weight.
Returns:
Returns:
A sensor with the same shape as the input.
A sensor with the same shape as the input.
"""
"""
one_hot_labels
=
tf
.
one_hot
(
one_hot_labels
=
tf
.
one_hot
(
chars_labels
,
depth
=
self
.
_params
.
num_char_classes
,
axis
=-
1
)
chars_labels
,
depth
=
self
.
_params
.
num_char_classes
,
axis
=-
1
)
pos_weight
=
1.0
-
weight
pos_weight
=
1.0
-
weight
neg_weight
=
weight
/
self
.
_params
.
num_char_classes
neg_weight
=
weight
/
self
.
_params
.
num_char_classes
return
one_hot_labels
*
pos_weight
+
neg_weight
return
one_hot_labels
*
pos_weight
+
neg_weight
...
@@ -433,20 +599,20 @@ class Model(object):
...
@@ -433,20 +599,20 @@ class Model(object):
also ignore all null chars after the first one.
also ignore all null chars after the first one.
Args:
Args:
chars_logits: logits for predicted characters,
chars_logits: logits for predicted characters,
shape=[batch_size,
shape=[batch_size,
seq_length, num_char_classes];
seq_length, num_char_classes];
chars_labels: ground truth ids of characters,
chars_labels: ground truth ids of characters,
shape=[batch_size,
shape=[batch_size,
seq_length];
seq_length];
mparams: method hyper parameters.
mparams: method hyper parameters.
Returns:
Returns:
A Tensor with shape [batch_size] - the log-perplexity for each sequence.
A Tensor with shape [batch_size] - the log-perplexity for each sequence.
"""
"""
mparams
=
self
.
_mparams
[
'sequence_loss_fn'
]
mparams
=
self
.
_mparams
[
'sequence_loss_fn'
]
with
tf
.
variable_scope
(
'sequence_loss_fn/SLF'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'sequence_loss_fn/SLF'
):
if
mparams
.
label_smoothing
>
0
:
if
mparams
.
label_smoothing
>
0
:
smoothed_one_hot_labels
=
self
.
label_smoothing_regularization
(
smoothed_one_hot_labels
=
self
.
label_smoothing_regularization
(
chars_labels
,
mparams
.
label_smoothing
)
chars_labels
,
mparams
.
label_smoothing
)
labels_list
=
tf
.
unstack
(
smoothed_one_hot_labels
,
axis
=
1
)
labels_list
=
tf
.
unstack
(
smoothed_one_hot_labels
,
axis
=
1
)
else
:
else
:
# NOTE: in case of sparse softmax we are not using one-hot
# NOTE: in case of sparse softmax we are not using one-hot
...
@@ -459,21 +625,21 @@ class Model(object):
...
@@ -459,21 +625,21 @@ class Model(object):
else
:
else
:
# Suppose that reject character is the last in the charset.
# Suppose that reject character is the last in the charset.
reject_char
=
tf
.
constant
(
reject_char
=
tf
.
constant
(
self
.
_params
.
num_char_classes
-
1
,
self
.
_params
.
num_char_classes
-
1
,
shape
=
(
batch_size
,
seq_length
),
shape
=
(
batch_size
,
seq_length
),
dtype
=
tf
.
int64
)
dtype
=
tf
.
int64
)
known_char
=
tf
.
not_equal
(
chars_labels
,
reject_char
)
known_char
=
tf
.
not_equal
(
chars_labels
,
reject_char
)
weights
=
tf
.
to_floa
t
(
known_char
)
weights
=
tf
.
cas
t
(
known_char
,
dtype
=
tf
.
float32
)
logits_list
=
tf
.
unstack
(
chars_logits
,
axis
=
1
)
logits_list
=
tf
.
unstack
(
chars_logits
,
axis
=
1
)
weights_list
=
tf
.
unstack
(
weights
,
axis
=
1
)
weights_list
=
tf
.
unstack
(
weights
,
axis
=
1
)
loss
=
tf
.
contrib
.
legacy_seq2seq
.
sequence_loss
(
loss
=
tf
.
contrib
.
legacy_seq2seq
.
sequence_loss
(
logits_list
,
logits_list
,
labels_list
,
labels_list
,
weights_list
,
weights_list
,
softmax_loss_function
=
get_softmax_loss_fn
(
mparams
.
label_smoothing
),
softmax_loss_function
=
get_softmax_loss_fn
(
mparams
.
label_smoothing
),
average_across_timesteps
=
mparams
.
average_across_timesteps
)
average_across_timesteps
=
mparams
.
average_across_timesteps
)
tf
.
losses
.
add_loss
(
loss
)
tf
.
compat
.
v1
.
losses
.
add_loss
(
loss
)
return
loss
return
loss
def
create_summaries
(
self
,
data
,
endpoints
,
charset
,
is_training
):
def
create_summaries
(
self
,
data
,
endpoints
,
charset
,
is_training
):
...
@@ -482,8 +648,8 @@ class Model(object):
...
@@ -482,8 +648,8 @@ class Model(object):
Args:
Args:
data: InputEndpoints namedtuple.
data: InputEndpoints namedtuple.
endpoints: OutputEndpoints namedtuple.
endpoints: OutputEndpoints namedtuple.
charset: A dictionary with mapping between character codes and
charset: A dictionary with mapping between character codes and
unicode
unicode
characters. Use the one provided by a dataset.charset.
characters. Use the one provided by a dataset.charset.
is_training: If True will create summary prefixes for training job,
is_training: If True will create summary prefixes for training job,
otherwise - for evaluation.
otherwise - for evaluation.
...
@@ -503,13 +669,14 @@ class Model(object):
...
@@ -503,13 +669,14 @@ class Model(object):
# tf.summary.text(sname('text/pr'), pr_text)
# tf.summary.text(sname('text/pr'), pr_text)
# gt_text = charset_mapper.get_text(data.labels[:max_outputs,:])
# gt_text = charset_mapper.get_text(data.labels[:max_outputs,:])
# tf.summary.text(sname('text/gt'), gt_text)
# tf.summary.text(sname('text/gt'), gt_text)
tf
.
summary
.
image
(
sname
(
'image'
),
data
.
images
,
max_outputs
=
max_outputs
)
tf
.
compat
.
v1
.
summary
.
image
(
sname
(
'image'
),
data
.
images
,
max_outputs
=
max_outputs
)
if
is_training
:
if
is_training
:
tf
.
summary
.
image
(
tf
.
compat
.
v1
.
summary
.
image
(
sname
(
'image/orig'
),
data
.
images_orig
,
max_outputs
=
max_outputs
)
sname
(
'image/orig'
),
data
.
images_orig
,
max_outputs
=
max_outputs
)
for
var
in
tf
.
trainable_variables
():
for
var
in
tf
.
compat
.
v1
.
trainable_variables
():
tf
.
summary
.
histogram
(
var
.
op
.
name
,
var
)
tf
.
compat
.
v1
.
summary
.
histogram
(
var
.
op
.
name
,
var
)
return
None
return
None
else
:
else
:
...
@@ -520,32 +687,36 @@ class Model(object):
...
@@ -520,32 +687,36 @@ class Model(object):
names_to_values
[
name
]
=
value_update_tuple
[
0
]
names_to_values
[
name
]
=
value_update_tuple
[
0
]
names_to_updates
[
name
]
=
value_update_tuple
[
1
]
names_to_updates
[
name
]
=
value_update_tuple
[
1
]
use_metric
(
'CharacterAccuracy'
,
use_metric
(
metrics
.
char_accuracy
(
'CharacterAccuracy'
,
endpoints
.
predicted_chars
,
metrics
.
char_accuracy
(
data
.
labels
,
endpoints
.
predicted_chars
,
streaming
=
True
,
data
.
labels
,
rej_char
=
self
.
_params
.
null_code
))
streaming
=
True
,
rej_char
=
self
.
_params
.
null_code
))
# Sequence accuracy computed by cutting sequence at the first null char
# Sequence accuracy computed by cutting sequence at the first null char
use_metric
(
'SequenceAccuracy'
,
use_metric
(
metrics
.
sequence_accuracy
(
'SequenceAccuracy'
,
endpoints
.
predicted_chars
,
metrics
.
sequence_accuracy
(
data
.
labels
,
endpoints
.
predicted_chars
,
streaming
=
True
,
data
.
labels
,
rej_char
=
self
.
_params
.
null_code
))
streaming
=
True
,
rej_char
=
self
.
_params
.
null_code
))
for
name
,
value
in
names_to_values
.
items
():
for
name
,
value
in
names_to_values
.
items
():
summary_name
=
'eval/'
+
name
summary_name
=
'eval/'
+
name
tf
.
summary
.
scalar
(
summary_name
,
tf
.
Print
(
value
,
[
value
],
summary_name
))
tf
.
compat
.
v1
.
summary
.
scalar
(
summary_name
,
tf
.
compat
.
v1
.
Print
(
value
,
[
value
],
summary_name
))
return
list
(
names_to_updates
.
values
())
return
list
(
names_to_updates
.
values
())
def
create_init_fn_to_restore
(
self
,
master_checkpoint
,
def
create_init_fn_to_restore
(
self
,
master_checkpoint
,
inception_checkpoint
=
None
):
inception_checkpoint
=
None
):
"""Creates an init operations to restore weights from various checkpoints.
"""Creates an init operations to restore weights from various checkpoints.
Args:
Args:
master_checkpoint: path to a checkpoint which contains all weights for
master_checkpoint: path to a checkpoint which contains all weights for
the
the
whole model.
whole model.
inception_checkpoint: path to a checkpoint which contains weights for the
inception_checkpoint: path to a checkpoint which contains weights for the
inception part only.
inception part only.
...
@@ -556,8 +727,8 @@ class Model(object):
...
@@ -556,8 +727,8 @@ class Model(object):
all_feed_dict
=
{}
all_feed_dict
=
{}
def
assign_from_checkpoint
(
variables
,
checkpoint
):
def
assign_from_checkpoint
(
variables
,
checkpoint
):
logging
.
info
(
'Request to re-store %d weights from %s'
,
logging
.
info
(
'Request to re-store %d weights from %s'
,
len
(
variables
),
len
(
variables
),
checkpoint
)
checkpoint
)
if
not
variables
:
if
not
variables
:
logging
.
error
(
'Can
\'
t find any variables to restore.'
)
logging
.
error
(
'Can
\'
t find any variables to restore.'
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
...
@@ -565,15 +736,18 @@ class Model(object):
...
@@ -565,15 +736,18 @@ class Model(object):
all_assign_ops
.
append
(
assign_op
)
all_assign_ops
.
append
(
assign_op
)
all_feed_dict
.
update
(
feed_dict
)
all_feed_dict
.
update
(
feed_dict
)
logging
.
info
(
'variables_to_restore:
\n
%s'
%
utils
.
variables_to_restore
().
keys
())
logging
.
info
(
'variables_to_restore:
\n
%s'
,
logging
.
info
(
'moving_average_variables:
\n
%s'
%
[
v
.
op
.
name
for
v
in
tf
.
moving_average_variables
()])
utils
.
variables_to_restore
().
keys
())
logging
.
info
(
'trainable_variables:
\n
%s'
%
[
v
.
op
.
name
for
v
in
tf
.
trainable_variables
()])
logging
.
info
(
'moving_average_variables:
\n
%s'
,
[
v
.
op
.
name
for
v
in
tf
.
compat
.
v1
.
moving_average_variables
()])
logging
.
info
(
'trainable_variables:
\n
%s'
,
[
v
.
op
.
name
for
v
in
tf
.
compat
.
v1
.
trainable_variables
()])
if
master_checkpoint
:
if
master_checkpoint
:
assign_from_checkpoint
(
utils
.
variables_to_restore
(),
master_checkpoint
)
assign_from_checkpoint
(
utils
.
variables_to_restore
(),
master_checkpoint
)
if
inception_checkpoint
:
if
inception_checkpoint
:
variables
=
utils
.
variables_to_restore
(
variables
=
utils
.
variables_to_restore
(
'AttentionOcr_v1/conv_tower_fn/INCE'
,
strip_scope
=
True
)
'AttentionOcr_v1/conv_tower_fn/INCE'
,
strip_scope
=
True
)
assign_from_checkpoint
(
variables
,
inception_checkpoint
)
assign_from_checkpoint
(
variables
,
inception_checkpoint
)
def
init_assign_fn
(
sess
):
def
init_assign_fn
(
sess
):
...
...
research/attention_ocr/python/model_export.py
0 → 100644
View file @
0cceabfc
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Converts existing checkpoint into a SavedModel.
Usage example:
python model_export.py \
--logtostderr --checkpoint=model.ckpt-399731 \
--export_dir=/tmp/attention_ocr_export
"""
import
os
import
tensorflow
as
tf
from
tensorflow
import
app
from
tensorflow.contrib
import
slim
from
tensorflow.python.platform
import
flags
import
common_flags
import
model_export_lib
# Command-line configuration for the export tool. Dataset/model/checkpoint
# flags are defined by common_flags; only export-specific flags are added here.
FLAGS = flags.FLAGS
common_flags.define()

flags.DEFINE_string('export_dir', None, 'Directory to export model files to.')
flags.DEFINE_integer(
    'image_width', None,
    'Image width used during training (or crop width if used)'
    ' If not set, the dataset default is used instead.')
flags.DEFINE_integer(
    'image_height', None,
    'Image height used during training(or crop height if used)'
    ' If not set, the dataset default is used instead.')
flags.DEFINE_string('work_dir', '/tmp', 'A directory to store temporary files.')
flags.DEFINE_integer('version_number', 1, 'Version number of the model')
flags.DEFINE_bool(
    'export_for_serving', True,
    'Whether the exported model accepts serialized tf.Example '
    'protos as input')
def get_checkpoint_path():
  """Returns a path to a checkpoint based on specified commandline flags.

  In order to specify a full path to a checkpoint use --checkpoint flag.
  Alternatively, if --train_log_dir was specified it will return a path to the
  most recent checkpoint.

  Raises:
    ValueError: in case it can't find a checkpoint.

  Returns:
    A string.
  """
  # An explicit --checkpoint always wins over the training log directory.
  if FLAGS.checkpoint:
    return FLAGS.checkpoint
  latest_path = tf.train.latest_checkpoint(FLAGS.train_log_dir)
  if not latest_path:
    raise ValueError('Can\'t find a checkpoint in: %s' % FLAGS.train_log_dir)
  return latest_path
def export_model(export_dir,
                 export_for_serving,
                 batch_size=None,
                 crop_image_width=None,
                 crop_image_height=None):
  """Exports a model to the named directory.

  Note that --dataset_name and --checkpoint are required and parsed by the
  underlying module common_flags.

  Args:
    export_dir: The output dir where model is exported to.
    export_for_serving: If True, expects a serialized image as input and attach
      image normalization as part of exported graph.
    batch_size: For non-serving export, the input batch_size needs to be
      specified.
    crop_image_width: Width of the input image. Uses the dataset default if
      None.
    crop_image_height: Height of the input image. Uses the dataset default if
      None.

  Returns:
    Returns the model signature_def.

  Raises:
    ValueError: if the dataset's charset file does not exist.
  """
  # Dataset object used only to get all parameters for the model.
  dataset = common_flags.create_dataset(split_name='test')
  model = common_flags.create_model(
      dataset.num_char_classes,
      dataset.max_sequence_length,
      dataset.num_of_views,
      dataset.null_code,
      charset=dataset.charset)
  dataset_image_height, dataset_image_width, image_depth = dataset.image_shape

  # Add check for charmap file. The exported model embeds a lookup table built
  # from this file, so a missing file would only fail later at load time.
  if not os.path.exists(dataset.charset_file):
    # BUG FIX: the message previously formatted dataset.charset (the mapping
    # dict) instead of the missing path the check actually inspects.
    raise ValueError('No charset defined at {}: export will fail'.format(
        dataset.charset_file))

  # Default to dataset dimensions, otherwise use provided dimensions.
  image_width = crop_image_width or dataset_image_width
  image_height = crop_image_height or dataset_image_height

  if export_for_serving:
    # Serving input: a batch of serialized tf.Example protos; decoding and
    # normalization to [0..1] floats is baked into the exported graph.
    images_orig = tf.compat.v1.placeholder(
        tf.string, shape=[batch_size], name='tf_example')
    images_orig_float = model_export_lib.generate_tfexample_image(
        images_orig,
        image_height,
        image_width,
        image_depth,
        name='float_images')
  else:
    # Direct-use input: a batch of raw uint8 images of a fixed shape.
    images_shape = (batch_size, image_height, image_width, image_depth)
    images_orig = tf.compat.v1.placeholder(
        tf.uint8, shape=images_shape, name='original_image')
    images_orig_float = tf.image.convert_image_dtype(
        images_orig, dtype=tf.float32, name='float_images')

  endpoints = model.create_base(images_orig_float, labels_one_hot=None)

  sess = tf.compat.v1.Session()
  saver = tf.compat.v1.train.Saver(
      slim.get_variables_to_restore(), sharded=True)
  saver.restore(sess, get_checkpoint_path())
  tf.compat.v1.logging.info('Model restored successfully.')

  # Create model signature.
  if export_for_serving:
    input_tensors = {tf.saved_model.CLASSIFY_INPUTS: images_orig}
  else:
    input_tensors = {'images': images_orig}
  signature_inputs = model_export_lib.build_tensor_info(input_tensors)
  # NOTE: Tensors 'image_float' and 'chars_logit' are used by the inference
  # or to compute saliency maps.
  output_tensors = {
      'images_float': images_orig_float,
      'predictions': endpoints.predicted_chars,
      'scores': endpoints.predicted_scores,
      'chars_logit': endpoints.chars_logit,
      'predicted_length': endpoints.predicted_length,
      'predicted_text': endpoints.predicted_text,
      'predicted_conf': endpoints.predicted_conf,
      'normalized_seq_conf': endpoints.normalized_seq_conf
  }
  # Expose the per-character attention masks as individually named outputs.
  for i, t in enumerate(
      model_export_lib.attention_ocr_attention_masks(
          dataset.max_sequence_length)):
    output_tensors['attention_mask_%d' % i] = t
  signature_outputs = model_export_lib.build_tensor_info(output_tensors)
  signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
      signature_inputs, signature_outputs, tf.saved_model.CLASSIFY_METHOD_NAME)
  # Save model.
  builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.SERVING],
      signature_def_map={
          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
      },
      # main_op runs at load time to initialize the charset lookup table.
      main_op=tf.compat.v1.tables_initializer(),
      strip_default_attrs=True)
  builder.save()
  # Use lazy %-style args so formatting only happens when the log is emitted.
  tf.compat.v1.logging.info('Model has been exported to %s', export_dir)

  return signature_def
def main(unused_argv):
  """Entry point: exports the model configured by commandline flags."""
  # Refuse to clobber an existing export; SavedModelBuilder requires a
  # non-existent target directory.
  if os.path.exists(FLAGS.export_dir):
    raise ValueError('export_dir already exists: exporting will fail')

  export_model(FLAGS.export_dir, FLAGS.export_for_serving, FLAGS.batch_size,
               FLAGS.image_width, FLAGS.image_height)
if __name__ == '__main__':
  # These flags have no usable defaults, so fail fast if they are missing.
  flags.mark_flag_as_required('dataset_name')
  flags.mark_flag_as_required('export_dir')
  app.run(main)
research/attention_ocr/python/model_export_lib.py
0 → 100644
View file @
0cceabfc
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for exporting Attention OCR model."""
import
tensorflow
as
tf
# Function borrowed from research/object_detection/core/preprocessor.py
def normalize_image(image, original_minval, original_maxval, target_minval,
                    target_maxval):
  """Normalizes pixel values in the image.

  Moves the pixel values from the current [original_minval, original_maxval]
  range to a the [target_minval, target_maxval] range.

  Args:
    image: rank 3 float32 tensor containing 1 image -> [height, width,
      channels].
    original_minval: current image minimum value.
    original_maxval: current image maximum value.
    target_minval: target image minimum value.
    target_maxval: target image maximum value.

  Returns:
    image: image which is the same shape as input image.
  """
  with tf.compat.v1.name_scope('NormalizeImage', values=[image]):
    orig_min = float(original_minval)
    orig_max = float(original_maxval)
    tgt_min = float(target_minval)
    tgt_max = float(target_maxval)
    # Linear remap: shift to zero, rescale to the target span, shift to the
    # target minimum.
    scale = (tgt_max - tgt_min) / (orig_max - orig_min)
    image = tf.cast(image, dtype=tf.float32)
    image = tf.subtract(image, orig_min)
    image = tf.multiply(image, scale)
    image = tf.add(image, tgt_min)
    return image
def generate_tfexample_image(input_example_strings,
                             image_height,
                             image_width,
                             image_channels,
                             name=None):
  """Parses a 1D tensor of serialized tf.Example protos and returns image batch.

  Args:
    input_example_strings: A 1-Dimensional tensor of size [batch_size] and type
      tf.string containing a serialized Example proto per image.
    image_height: First image dimension.
    image_width: Second image dimension.
    image_channels: Third image dimension.
    name: optional tensor name.

  Returns:
    A tensor with shape [batch_size, height, width, channels] of type float32
    with values in the range [0..1]
  """
  tf_example_image_key = 'image/encoded'
  # Each Example stores the image as a flat float list of fixed length.
  num_pixels = image_height * image_width * image_channels
  feature_configs = {
      tf_example_image_key:
          tf.io.FixedLenFeature(num_pixels, dtype=tf.float32)
  }
  parsed = tf.io.parse_example(
      serialized=input_example_strings, features=feature_configs)
  # Batch size is dynamic, so build the output shape as a tensor.
  num_examples = tf.shape(input=input_example_strings)[0]
  images_shape = tf.stack(
      [num_examples, image_height, image_width, image_channels])
  # Rescale [0..255] pixel values into [0..1] and restore spatial layout.
  normalized = normalize_image(
      parsed[tf_example_image_key],
      original_minval=0.0,
      original_maxval=255.0,
      target_minval=0.0,
      target_maxval=1.0)
  return tf.reshape(normalized, images_shape, name=name)
def attention_ocr_attention_masks(num_characters):
  """Fetches the attention-softmax tensors for each decoded character.

  Args:
    num_characters: number of decoding steps (max sequence length).

  Returns:
    A list of attention mask tensors, one per character position.
  """
  # TODO(gorban): use tensors directly after replacing LSTM unroll methods.
  prefix = ('AttentionOcr_v1/'
            'sequence_logit_fn/SQLR/LSTM/attention_decoder/Attention_0')
  # The first unrolled step has no numeric suffix; later steps are '_1', '_2'…
  names = ['%s/Softmax:0' % (prefix)]
  names.extend('%s_%d/Softmax:0' % (prefix, i)
               for i in range(1, num_characters))
  graph = tf.compat.v1.get_default_graph()
  return [graph.get_tensor_by_name(n) for n in names]
def build_tensor_info(tensor_dict):
  """Converts a dict of tensors into a dict of TensorInfo protos.

  Args:
    tensor_dict: mapping from signature key to tf.Tensor.

  Returns:
    A dict with the same keys whose values are TensorInfo protos suitable for
    a SavedModel signature_def.
  """
  result = {}
  for key, tensor in tensor_dict.items():
    result[key] = tf.compat.v1.saved_model.utils.build_tensor_info(tensor)
  return result
research/attention_ocr/python/model_export_test.py
0 → 100644
View file @
0cceabfc
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model_export."""
import
os
import
numpy
as
np
from
absl.testing
import
flagsaver
import
tensorflow
as
tf
import
common_flags
import
model_export
# Pretrained checkpoint the export tests restore from; it must be downloaded
# and extracted manually from _CHECKPOINT_URL before running the tests.
_CHECKPOINT = 'model.ckpt-399731'
_CHECKPOINT_URL = (
    'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz')
def _clean_up():
  """Removes the test temp directory so each test starts from a clean state."""
  tf.io.gfile.rmtree(tf.compat.v1.test.get_temp_dir())
def _create_tf_example_string(image):
  """Create a serialized tf.Example proto for feeding the model."""
  # Flatten the image into the 'image/encoded' float list feature expected by
  # the serving input pipeline.
  flat_pixels = list(np.reshape(image, (-1)))
  example = tf.train.Example()
  feature = example.features.feature['image/encoded']
  feature.float_list.value.extend(flat_pixels)
  return example.SerializeToString()
class AttentionOcrExportTest(tf.test.TestCase):
  """Tests for model_export.export_model."""

  def setUp(self):
    # Fail early with a download hint if the pretrained checkpoint shards are
    # not present next to the test.
    for suffix in ['.meta', '.index', '.data-00000-of-00001']:
      filename = _CHECKPOINT + suffix
      self.assertTrue(
          tf.io.gfile.exists(filename),
          msg='Missing checkpoint file %s. '
          'Please download and extract it from %s' % (filename,
                                                      _CHECKPOINT_URL))
    # Configure the flags the export module reads (dataset, checkpoint, data
    # dir). flagsaver on each test restores them afterwards.
    tf.flags.FLAGS.dataset_name = 'fsns'
    tf.flags.FLAGS.checkpoint = _CHECKPOINT
    tf.flags.FLAGS.dataset_dir = os.path.join(
        os.path.dirname(__file__), 'datasets/testdata/fsns')
    tf.test.TestCase.setUp(self)
    _clean_up()
    self.export_dir = os.path.join(tf.compat.v1.test.get_temp_dir(),
                                   'exported_model')
    # Output tensor names every exported model must expose, keyed by the
    # signature output name each maps to.
    self.minimal_output_signature = {
        'predictions': 'AttentionOcr_v1/predicted_chars:0',
        'scores': 'AttentionOcr_v1/predicted_scores:0',
        'predicted_length': 'AttentionOcr_v1/predicted_length:0',
        'predicted_text': 'AttentionOcr_v1/predicted_text:0',
        'predicted_conf': 'AttentionOcr_v1/predicted_conf:0',
        'normalized_seq_conf': 'AttentionOcr_v1/normalized_seq_conf:0'
    }

  def create_input_feed(self, graph_def, serving):
    """Returns the input feed for the model.

    Creates random images, according to the size specified by dataset_name,
    format it in the correct way depending on whether the model was exported
    for serving, and return the correctly keyed feed_dict for inference.

    Args:
      graph_def: Graph definition of the loaded model.
      serving: Whether the model was exported for Serving.

    Returns:
      The feed_dict suitable for model inference.
    """
    # Creates a dataset based on FLAGS.dataset_name.
    self.dataset = common_flags.create_dataset('test')
    # Create some random images to test inference for any dataset.
    self.images = {
        'img1':
            np.random.uniform(low=64, high=192,
                              size=self.dataset.image_shape).astype('uint8'),
        'img2':
            np.random.uniform(low=32, high=224,
                              size=self.dataset.image_shape).astype('uint8'),
    }
    signature_def = graph_def.signature_def[
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    if serving:
      input_name = signature_def.inputs[tf.saved_model.CLASSIFY_INPUTS].name
      # Model for serving takes input: inputs['inputs'] = 'tf_example:0'
      feed_dict = {
          input_name: [
              _create_tf_example_string(self.images['img1']),
              _create_tf_example_string(self.images['img2'])
          ]
      }
    else:
      input_name = signature_def.inputs['images'].name
      # Model for direct use takes input: inputs['images'] = 'original_image:0'
      feed_dict = {
          input_name: np.stack([self.images['img1'], self.images['img2']])
      }
    return feed_dict

  def verify_export_load_and_inference(self, export_for_serving=False):
    """Verify exported model can be loaded and inference can run successfully.

    This function will load the exported model in self.export_dir, then create
    some fake images according to the specification of FLAGS.dataset_name.
    It then feeds the input through the model, and verify the minimal set of
    output signatures are present.

    Note: Model and dataset creation in the underlying library depends on the
    following commandline flags:
      FLAGS.dataset_name

    Args:
      export_for_serving: True if the model was exported for Serving. This
        affects how input is fed into the model.
    """
    tf.compat.v1.reset_default_graph()
    sess = tf.compat.v1.Session()
    graph_def = tf.compat.v1.saved_model.loader.load(
        sess=sess,
        tags=[tf.saved_model.SERVING],
        export_dir=self.export_dir)
    feed_dict = self.create_input_feed(graph_def, export_for_serving)
    results = sess.run(self.minimal_output_signature, feed_dict=feed_dict)
    # Batch of 2 images -> per-sequence outputs have shape (2,).
    out_shape = (2,)
    self.assertEqual(np.shape(results['predicted_conf']), out_shape)
    self.assertEqual(np.shape(results['predicted_text']), out_shape)
    self.assertEqual(np.shape(results['predicted_length']), out_shape)
    self.assertEqual(np.shape(results['normalized_seq_conf']), out_shape)
    # Per-character outputs have shape (batch, max_sequence_length).
    out_shape = (2, self.dataset.max_sequence_length)
    self.assertEqual(np.shape(results['scores']), out_shape)
    self.assertEqual(np.shape(results['predictions']), out_shape)

  @flagsaver.flagsaver
  def test_fsns_export_for_serving_and_load_inference(self):
    model_export.export_model(self.export_dir, True)
    self.verify_export_load_and_inference(True)

  @flagsaver.flagsaver
  def test_fsns_export_and_load_inference(self):
    model_export.export_model(self.export_dir, False, batch_size=2)
    self.verify_export_load_and_inference(False)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/attention_ocr/python/model_test.py
View file @
0cceabfc
...
@@ -12,11 +12,10 @@
...
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Tests for the model."""
"""Tests for the model."""
import
string
import
numpy
as
np
import
numpy
as
np
import
string
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.contrib
import
slim
from
tensorflow.contrib
import
slim
...
@@ -32,6 +31,7 @@ def create_fake_charset(num_char_classes):
...
@@ -32,6 +31,7 @@ def create_fake_charset(num_char_classes):
class
ModelTest
(
tf
.
test
.
TestCase
):
class
ModelTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
tf
.
test
.
TestCase
.
setUp
(
self
)
tf
.
test
.
TestCase
.
setUp
(
self
)
...
@@ -51,18 +51,21 @@ class ModelTest(tf.test.TestCase):
...
@@ -51,18 +51,21 @@ class ModelTest(tf.test.TestCase):
self
.
chars_logit_shape
=
(
self
.
batch_size
,
self
.
seq_length
,
self
.
chars_logit_shape
=
(
self
.
batch_size
,
self
.
seq_length
,
self
.
num_char_classes
)
self
.
num_char_classes
)
self
.
length_logit_shape
=
(
self
.
batch_size
,
self
.
seq_length
+
1
)
self
.
length_logit_shape
=
(
self
.
batch_size
,
self
.
seq_length
+
1
)
# Placeholder knows image dimensions, but not batch size.
self
.
input_images
=
tf
.
compat
.
v1
.
placeholder
(
tf
.
float32
,
shape
=
(
None
,
self
.
image_height
,
self
.
image_width
,
3
),
name
=
'input_node'
)
self
.
initialize_fakes
()
self
.
initialize_fakes
()
def
initialize_fakes
(
self
):
def
initialize_fakes
(
self
):
self
.
images_shape
=
(
self
.
batch_size
,
self
.
image_height
,
self
.
image_width
,
self
.
images_shape
=
(
self
.
batch_size
,
self
.
image_height
,
self
.
image_width
,
3
)
3
)
self
.
fake_images
=
tf
.
constant
(
self
.
fake_images
=
self
.
rng
.
randint
(
self
.
rng
.
randint
(
low
=
0
,
high
=
255
,
low
=
0
,
high
=
255
,
size
=
self
.
images_shape
).
astype
(
'float32'
)
size
=
self
.
images_shape
).
astype
(
'float32'
),
self
.
fake_conv_tower_np
=
self
.
rng
.
randn
(
*
self
.
conv_tower_shape
).
astype
(
name
=
'input_node'
)
'float32'
)
self
.
fake_conv_tower_np
=
self
.
rng
.
randn
(
*
self
.
conv_tower_shape
).
astype
(
'float32'
)
self
.
fake_conv_tower
=
tf
.
constant
(
self
.
fake_conv_tower_np
)
self
.
fake_conv_tower
=
tf
.
constant
(
self
.
fake_conv_tower_np
)
self
.
fake_logits
=
tf
.
constant
(
self
.
fake_logits
=
tf
.
constant
(
self
.
rng
.
randn
(
*
self
.
chars_logit_shape
).
astype
(
'float32'
))
self
.
rng
.
randn
(
*
self
.
chars_logit_shape
).
astype
(
'float32'
))
...
@@ -74,33 +77,44 @@ class ModelTest(tf.test.TestCase):
...
@@ -74,33 +77,44 @@ class ModelTest(tf.test.TestCase):
def
create_model
(
self
,
charset
=
None
):
def
create_model
(
self
,
charset
=
None
):
return
model
.
Model
(
return
model
.
Model
(
self
.
num_char_classes
,
self
.
seq_length
,
num_views
=
4
,
null_code
=
62
,
self
.
num_char_classes
,
self
.
seq_length
,
num_views
=
4
,
null_code
=
62
,
charset
=
charset
)
charset
=
charset
)
def
test_char_related_shapes
(
self
):
def
test_char_related_shapes
(
self
):
ocr_model
=
self
.
create_model
()
charset
=
create_fake_charset
(
self
.
num_char_classes
)
ocr_model
=
self
.
create_model
(
charset
=
charset
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
endpoints_tf
=
ocr_model
.
create_base
(
endpoints_tf
=
ocr_model
.
create_base
(
images
=
self
.
fake_images
,
labels_one_hot
=
None
)
images
=
self
.
input_images
,
labels_one_hot
=
None
)
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
sess
.
run
(
tf
.
global_variables_initializer
())
tf
.
compat
.
v1
.
tables_initializer
().
run
()
endpoints
=
sess
.
run
(
endpoints_tf
)
endpoints
=
sess
.
run
(
endpoints_tf
,
feed_dict
=
{
self
.
input_images
:
self
.
fake_images
})
self
.
assertEqual
((
self
.
batch_size
,
self
.
seq_length
,
self
.
num_char_classes
),
endpoints
.
chars_logit
.
shape
)
self
.
assertEqual
(
self
.
assertEqual
((
self
.
batch_size
,
self
.
seq_length
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
num_char_classes
),
self
.
num_char_classes
),
endpoints
.
chars_log_prob
.
shape
)
endpoints
.
chars_logit
.
shape
)
self
.
assertEqual
(
(
self
.
batch_size
,
self
.
seq_length
,
self
.
num_char_classes
),
endpoints
.
chars_log_prob
.
shape
)
self
.
assertEqual
((
self
.
batch_size
,
self
.
seq_length
),
self
.
assertEqual
((
self
.
batch_size
,
self
.
seq_length
),
endpoints
.
predicted_chars
.
shape
)
endpoints
.
predicted_chars
.
shape
)
self
.
assertEqual
((
self
.
batch_size
,
self
.
seq_length
),
self
.
assertEqual
((
self
.
batch_size
,
self
.
seq_length
),
endpoints
.
predicted_scores
.
shape
)
endpoints
.
predicted_scores
.
shape
)
self
.
assertEqual
((
self
.
batch_size
,),
endpoints
.
predicted_text
.
shape
)
self
.
assertEqual
((
self
.
batch_size
,),
endpoints
.
predicted_conf
.
shape
)
self
.
assertEqual
((
self
.
batch_size
,),
endpoints
.
normalized_seq_conf
.
shape
)
def
test_predicted_scores_are_within_range
(
self
):
def
test_predicted_scores_are_within_range
(
self
):
ocr_model
=
self
.
create_model
()
ocr_model
=
self
.
create_model
()
_
,
_
,
scores
=
ocr_model
.
char_predictions
(
self
.
fake_logits
)
_
,
_
,
scores
=
ocr_model
.
char_predictions
(
self
.
fake_logits
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
scores_np
=
sess
.
run
(
scores
)
scores_np
=
sess
.
run
(
scores
,
feed_dict
=
{
self
.
input_images
:
self
.
fake_images
})
values_in_range
=
(
scores_np
>=
0.0
)
&
(
scores_np
<=
1.0
)
values_in_range
=
(
scores_np
>=
0.0
)
&
(
scores_np
<=
1.0
)
self
.
assertTrue
(
self
.
assertTrue
(
...
@@ -111,10 +125,11 @@ class ModelTest(tf.test.TestCase):
...
@@ -111,10 +125,11 @@ class ModelTest(tf.test.TestCase):
def
test_conv_tower_shape
(
self
):
def
test_conv_tower_shape
(
self
):
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
ocr_model
=
self
.
create_model
()
ocr_model
=
self
.
create_model
()
conv_tower
=
ocr_model
.
conv_tower_fn
(
self
.
fake
_images
)
conv_tower
=
ocr_model
.
conv_tower_fn
(
self
.
input
_images
)
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
conv_tower_np
=
sess
.
run
(
conv_tower
)
conv_tower_np
=
sess
.
run
(
conv_tower
,
feed_dict
=
{
self
.
input_images
:
self
.
fake_images
})
self
.
assertEqual
(
self
.
conv_tower_shape
,
conv_tower_np
.
shape
)
self
.
assertEqual
(
self
.
conv_tower_shape
,
conv_tower_np
.
shape
)
...
@@ -124,11 +139,12 @@ class ModelTest(tf.test.TestCase):
...
@@ -124,11 +139,12 @@ class ModelTest(tf.test.TestCase):
# updates, gradients and variances. It also depends on the type of used
# updates, gradients and variances. It also depends on the type of used
# optimizer.
# optimizer.
ocr_model
=
self
.
create_model
()
ocr_model
=
self
.
create_model
()
ocr_model
.
create_base
(
images
=
self
.
fake
_images
,
labels_one_hot
=
None
)
ocr_model
.
create_base
(
images
=
self
.
input
_images
,
labels_one_hot
=
None
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
tfprof_root
=
tf
.
profiler
.
profile
(
tfprof_root
=
tf
.
compat
.
v1
.
profiler
.
profile
(
sess
.
graph
,
sess
.
graph
,
options
=
tf
.
profiler
.
ProfileOptionBuilder
.
trainable_variables_parameter
())
options
=
tf
.
compat
.
v1
.
profiler
.
ProfileOptionBuilder
.
trainable_variables_parameter
())
model_size_bytes
=
4
*
tfprof_root
.
total_parameters
model_size_bytes
=
4
*
tfprof_root
.
total_parameters
self
.
assertLess
(
model_size_bytes
,
1
*
2
**
30
)
self
.
assertLess
(
model_size_bytes
,
1
*
2
**
30
)
...
@@ -147,9 +163,9 @@ class ModelTest(tf.test.TestCase):
...
@@ -147,9 +163,9 @@ class ModelTest(tf.test.TestCase):
summaries
=
ocr_model
.
create_summaries
(
summaries
=
ocr_model
.
create_summaries
(
data
,
endpoints
,
charset
,
is_training
=
False
)
data
,
endpoints
,
charset
,
is_training
=
False
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
sess
.
run
(
tf
.
local_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
local_variables_initializer
())
tf
.
tables_initializer
().
run
()
tf
.
compat
.
v1
.
tables_initializer
().
run
()
sess
.
run
(
summaries
)
# just check it is runnable
sess
.
run
(
summaries
)
# just check it is runnable
def
test_sequence_loss_function_without_label_smoothing
(
self
):
def
test_sequence_loss_function_without_label_smoothing
(
self
):
...
@@ -158,7 +174,7 @@ class ModelTest(tf.test.TestCase):
...
@@ -158,7 +174,7 @@ class ModelTest(tf.test.TestCase):
loss
=
model
.
sequence_loss_fn
(
self
.
fake_logits
,
self
.
fake_labels
)
loss
=
model
.
sequence_loss_fn
(
self
.
fake_logits
,
self
.
fake_labels
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
loss_np
=
sess
.
run
(
loss
)
loss_np
=
sess
.
run
(
loss
,
feed_dict
=
{
self
.
input_images
:
self
.
fake_images
}
)
# This test checks that the loss function is 'runnable'.
# This test checks that the loss function is 'runnable'.
self
.
assertEqual
(
loss_np
.
shape
,
tuple
())
self
.
assertEqual
(
loss_np
.
shape
,
tuple
())
...
@@ -172,19 +188,21 @@ class ModelTest(tf.test.TestCase):
...
@@ -172,19 +188,21 @@ class ModelTest(tf.test.TestCase):
Returns:
Returns:
a list of tensors with encoded image coordinates in them.
a list of tensors with encoded image coordinates in them.
"""
"""
batch_size
,
h
,
w
,
_
=
net
.
shape
.
as_list
()
batch_size
=
tf
.
shape
(
input
=
net
)[
0
]
_
,
h
,
w
,
_
=
net
.
shape
.
as_list
()
h_loc
=
[
h_loc
=
[
tf
.
tile
(
tf
.
tile
(
tf
.
reshape
(
tf
.
reshape
(
tf
.
contrib
.
layers
.
one_hot_encoding
(
tf
.
contrib
.
layers
.
one_hot_encoding
(
tf
.
constant
([
i
]),
num_classes
=
h
),
[
h
,
1
]),
[
1
,
w
])
tf
.
constant
([
i
]),
num_classes
=
h
),
[
h
,
1
]),
[
1
,
w
])
for
i
in
range
(
h
)
for
i
in
range
(
h
)
]
]
h_loc
=
tf
.
concat
([
tf
.
expand_dims
(
t
,
2
)
for
t
in
h_loc
],
2
)
h_loc
=
tf
.
concat
([
tf
.
expand_dims
(
t
,
2
)
for
t
in
h_loc
],
2
)
w_loc
=
[
w_loc
=
[
tf
.
tile
(
tf
.
tile
(
tf
.
contrib
.
layers
.
one_hot_encoding
(
tf
.
constant
([
i
]),
num_classes
=
w
),
tf
.
contrib
.
layers
.
one_hot_encoding
(
[
h
,
1
])
for
i
in
range
(
w
)
tf
.
constant
([
i
]),
num_classes
=
w
),
[
h
,
1
])
for
i
in
range
(
w
)
]
]
w_loc
=
tf
.
concat
([
tf
.
expand_dims
(
t
,
2
)
for
t
in
w_loc
],
2
)
w_loc
=
tf
.
concat
([
tf
.
expand_dims
(
t
,
2
)
for
t
in
w_loc
],
2
)
loc
=
tf
.
concat
([
h_loc
,
w_loc
],
2
)
loc
=
tf
.
concat
([
h_loc
,
w_loc
],
2
)
...
@@ -197,11 +215,12 @@ class ModelTest(tf.test.TestCase):
...
@@ -197,11 +215,12 @@ class ModelTest(tf.test.TestCase):
conv_w_coords_tf
=
model
.
encode_coordinates_fn
(
self
.
fake_conv_tower
)
conv_w_coords_tf
=
model
.
encode_coordinates_fn
(
self
.
fake_conv_tower
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
conv_w_coords
=
sess
.
run
(
conv_w_coords_tf
)
conv_w_coords
=
sess
.
run
(
conv_w_coords_tf
,
feed_dict
=
{
self
.
input_images
:
self
.
fake_images
})
batch_size
,
height
,
width
,
feature_size
=
self
.
conv_tower_shape
batch_size
,
height
,
width
,
feature_size
=
self
.
conv_tower_shape
self
.
assertEqual
(
conv_w_coords
.
shape
,
(
batch_size
,
height
,
width
,
self
.
assertEqual
(
conv_w_coords
.
shape
,
feature_size
+
height
+
width
))
(
batch_size
,
height
,
width
,
feature_size
+
height
+
width
))
def
test_disabled_coordinate_encoding_returns_features_unchanged
(
self
):
def
test_disabled_coordinate_encoding_returns_features_unchanged
(
self
):
model
=
self
.
create_model
()
model
=
self
.
create_model
()
...
@@ -209,7 +228,8 @@ class ModelTest(tf.test.TestCase):
...
@@ -209,7 +228,8 @@ class ModelTest(tf.test.TestCase):
conv_w_coords_tf
=
model
.
encode_coordinates_fn
(
self
.
fake_conv_tower
)
conv_w_coords_tf
=
model
.
encode_coordinates_fn
(
self
.
fake_conv_tower
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
conv_w_coords
=
sess
.
run
(
conv_w_coords_tf
)
conv_w_coords
=
sess
.
run
(
conv_w_coords_tf
,
feed_dict
=
{
self
.
input_images
:
self
.
fake_images
})
self
.
assertAllEqual
(
conv_w_coords
,
self
.
fake_conv_tower_np
)
self
.
assertAllEqual
(
conv_w_coords
,
self
.
fake_conv_tower_np
)
...
@@ -221,7 +241,8 @@ class ModelTest(tf.test.TestCase):
...
@@ -221,7 +241,8 @@ class ModelTest(tf.test.TestCase):
conv_w_coords_tf
=
model
.
encode_coordinates_fn
(
fake_conv_tower
)
conv_w_coords_tf
=
model
.
encode_coordinates_fn
(
fake_conv_tower
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
conv_w_coords
=
sess
.
run
(
conv_w_coords_tf
)
conv_w_coords
=
sess
.
run
(
conv_w_coords_tf
,
feed_dict
=
{
self
.
input_images
:
self
.
fake_images
})
# Original features
# Original features
self
.
assertAllEqual
(
conv_w_coords
[
0
,
:,
:,
:
4
],
self
.
assertAllEqual
(
conv_w_coords
[
0
,
:,
:,
:
4
],
...
@@ -252,8 +273,8 @@ class ModelTest(tf.test.TestCase):
...
@@ -252,8 +273,8 @@ class ModelTest(tf.test.TestCase):
endpoints_tf
=
ocr_model
.
create_base
(
endpoints_tf
=
ocr_model
.
create_base
(
images
=
self
.
fake_images
,
labels_one_hot
=
None
)
images
=
self
.
fake_images
,
labels_one_hot
=
None
)
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
tf
.
tables_initializer
().
run
()
tf
.
compat
.
v1
.
tables_initializer
().
run
()
endpoints
=
sess
.
run
(
endpoints_tf
)
endpoints
=
sess
.
run
(
endpoints_tf
)
self
.
assertEqual
(
endpoints
.
predicted_text
.
shape
,
(
self
.
batch_size
,))
self
.
assertEqual
(
endpoints
.
predicted_text
.
shape
,
(
self
.
batch_size
,))
...
@@ -261,14 +282,15 @@ class ModelTest(tf.test.TestCase):
...
@@ -261,14 +282,15 @@ class ModelTest(tf.test.TestCase):
class
CharsetMapperTest
(
tf
.
test
.
TestCase
):
class
CharsetMapperTest
(
tf
.
test
.
TestCase
):
def
test_text_corresponds_to_ids
(
self
):
def
test_text_corresponds_to_ids
(
self
):
charset
=
create_fake_charset
(
36
)
charset
=
create_fake_charset
(
36
)
ids
=
tf
.
constant
(
ids
=
tf
.
constant
(
[[
17
,
14
,
21
,
21
,
24
],
[
32
,
24
,
27
,
21
,
13
]],
[[
17
,
14
,
21
,
21
,
24
],
[
32
,
24
,
27
,
21
,
13
]],
dtype
=
tf
.
int64
)
dtype
=
tf
.
int64
)
charset_mapper
=
model
.
CharsetMapper
(
charset
)
charset_mapper
=
model
.
CharsetMapper
(
charset
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
tf
.
tables_initializer
().
run
()
tf
.
compat
.
v1
.
tables_initializer
().
run
()
text
=
sess
.
run
(
charset_mapper
.
get_text
(
ids
))
text
=
sess
.
run
(
charset_mapper
.
get_text
(
ids
))
self
.
assertAllEqual
(
text
,
[
b
'hello'
,
b
'world'
])
self
.
assertAllEqual
(
text
,
[
b
'hello'
,
b
'world'
])
...
...
research/attention_ocr/python/sequence_layers.py
View file @
0cceabfc
...
@@ -111,12 +111,12 @@ class SequenceLayerBase(object):
...
@@ -111,12 +111,12 @@ class SequenceLayerBase(object):
self
.
_mparams
=
method_params
self
.
_mparams
=
method_params
self
.
_net
=
net
self
.
_net
=
net
self
.
_labels_one_hot
=
labels_one_hot
self
.
_labels_one_hot
=
labels_one_hot
self
.
_batch_size
=
net
.
get_shape
().
dims
[
0
].
value
self
.
_batch_size
=
tf
.
shape
(
input
=
net
)[
0
]
# Initialize parameters for char logits which will be computed on the fly
# Initialize parameters for char logits which will be computed on the fly
# inside an LSTM decoder.
# inside an LSTM decoder.
self
.
_char_logits
=
{}
self
.
_char_logits
=
{}
regularizer
=
slim
.
l2_
regularizer
(
self
.
_mparams
.
weight_decay
)
regularizer
=
tf
.
keras
.
regularizer
s
.
l2
(
0.5
*
(
self
.
_mparams
.
weight_decay
)
)
self
.
_softmax_w
=
slim
.
model_variable
(
self
.
_softmax_w
=
slim
.
model_variable
(
'softmax_w'
,
'softmax_w'
,
[
self
.
_mparams
.
num_lstm_units
,
self
.
_params
.
num_char_classes
],
[
self
.
_mparams
.
num_lstm_units
,
self
.
_params
.
num_char_classes
],
...
@@ -124,7 +124,7 @@ class SequenceLayerBase(object):
...
@@ -124,7 +124,7 @@ class SequenceLayerBase(object):
regularizer
=
regularizer
)
regularizer
=
regularizer
)
self
.
_softmax_b
=
slim
.
model_variable
(
self
.
_softmax_b
=
slim
.
model_variable
(
'softmax_b'
,
[
self
.
_params
.
num_char_classes
],
'softmax_b'
,
[
self
.
_params
.
num_char_classes
],
initializer
=
tf
.
zeros_initializer
(),
initializer
=
tf
.
compat
.
v1
.
zeros_initializer
(),
regularizer
=
regularizer
)
regularizer
=
regularizer
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
...
@@ -203,8 +203,8 @@ class SequenceLayerBase(object):
...
@@ -203,8 +203,8 @@ class SequenceLayerBase(object):
A tensor with shape [batch_size, num_char_classes]
A tensor with shape [batch_size, num_char_classes]
"""
"""
if
char_index
not
in
self
.
_char_logits
:
if
char_index
not
in
self
.
_char_logits
:
self
.
_char_logits
[
char_index
]
=
tf
.
nn
.
xw_plus_b
(
inputs
,
self
.
_softmax_w
,
self
.
_char_logits
[
char_index
]
=
tf
.
compat
.
v1
.
nn
.
xw_plus_b
(
inputs
,
self
.
_softmax_w
,
self
.
_softmax_b
)
self
.
_softmax_b
)
return
self
.
_char_logits
[
char_index
]
return
self
.
_char_logits
[
char_index
]
def
char_one_hot
(
self
,
logit
):
def
char_one_hot
(
self
,
logit
):
...
@@ -216,7 +216,7 @@ class SequenceLayerBase(object):
...
@@ -216,7 +216,7 @@ class SequenceLayerBase(object):
Returns:
Returns:
A tensor with shape [batch_size, num_char_classes]
A tensor with shape [batch_size, num_char_classes]
"""
"""
prediction
=
tf
.
argmax
(
logit
,
axis
=
1
)
prediction
=
tf
.
argmax
(
input
=
logit
,
axis
=
1
)
return
slim
.
one_hot_encoding
(
prediction
,
self
.
_params
.
num_char_classes
)
return
slim
.
one_hot_encoding
(
prediction
,
self
.
_params
.
num_char_classes
)
def
get_input
(
self
,
prev
,
i
):
def
get_input
(
self
,
prev
,
i
):
...
@@ -244,10 +244,10 @@ class SequenceLayerBase(object):
...
@@ -244,10 +244,10 @@ class SequenceLayerBase(object):
Returns:
Returns:
A tensor with shape [batch_size, seq_length, num_char_classes].
A tensor with shape [batch_size, seq_length, num_char_classes].
"""
"""
with
tf
.
variable_scope
(
'LSTM'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'LSTM'
):
first_label
=
self
.
get_input
(
prev
=
None
,
i
=
0
)
first_label
=
self
.
get_input
(
prev
=
None
,
i
=
0
)
decoder_inputs
=
[
first_label
]
+
[
None
]
*
(
self
.
_params
.
seq_length
-
1
)
decoder_inputs
=
[
first_label
]
+
[
None
]
*
(
self
.
_params
.
seq_length
-
1
)
lstm_cell
=
tf
.
co
ntrib
.
rnn
.
LSTMCell
(
lstm_cell
=
tf
.
co
mpat
.
v1
.
nn
.
rnn_cell
.
LSTMCell
(
self
.
_mparams
.
num_lstm_units
,
self
.
_mparams
.
num_lstm_units
,
use_peepholes
=
False
,
use_peepholes
=
False
,
cell_clip
=
self
.
_mparams
.
lstm_state_clip_value
,
cell_clip
=
self
.
_mparams
.
lstm_state_clip_value
,
...
@@ -259,9 +259,9 @@ class SequenceLayerBase(object):
...
@@ -259,9 +259,9 @@ class SequenceLayerBase(object):
loop_function
=
self
.
get_input
,
loop_function
=
self
.
get_input
,
cell
=
lstm_cell
)
cell
=
lstm_cell
)
with
tf
.
variable_scope
(
'logits'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'logits'
):
logits_list
=
[
logits_list
=
[
tf
.
expand_dims
(
self
.
char_logit
(
logit
,
i
),
dim
=
1
)
tf
.
expand_dims
(
self
.
char_logit
(
logit
,
i
),
axis
=
1
)
for
i
,
logit
in
enumerate
(
lstm_outputs
)
for
i
,
logit
in
enumerate
(
lstm_outputs
)
]
]
...
@@ -275,7 +275,7 @@ class NetSlice(SequenceLayerBase):
...
@@ -275,7 +275,7 @@ class NetSlice(SequenceLayerBase):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
NetSlice
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
(
NetSlice
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
_zero_label
=
tf
.
zeros
(
self
.
_zero_label
=
tf
.
zeros
(
[
self
.
_batch_size
,
self
.
_params
.
num_char_classes
])
tf
.
stack
(
[
self
.
_batch_size
,
self
.
_params
.
num_char_classes
])
)
def
get_image_feature
(
self
,
char_index
):
def
get_image_feature
(
self
,
char_index
):
"""Returns a subset of image features for a character.
"""Returns a subset of image features for a character.
...
@@ -352,7 +352,7 @@ class Attention(SequenceLayerBase):
...
@@ -352,7 +352,7 @@ class Attention(SequenceLayerBase):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
Attention
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
(
Attention
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
_zero_label
=
tf
.
zeros
(
self
.
_zero_label
=
tf
.
zeros
(
[
self
.
_batch_size
,
self
.
_params
.
num_char_classes
])
tf
.
stack
(
[
self
.
_batch_size
,
self
.
_params
.
num_char_classes
])
)
def
get_eval_input
(
self
,
prev
,
i
):
def
get_eval_input
(
self
,
prev
,
i
):
"""See SequenceLayerBase.get_eval_input for details."""
"""See SequenceLayerBase.get_eval_input for details."""
...
...
research/attention_ocr/python/sequence_layers_test.py
View file @
0cceabfc
...
@@ -29,13 +29,13 @@ import sequence_layers
...
@@ -29,13 +29,13 @@ import sequence_layers
def
fake_net
(
batch_size
,
num_features
,
feature_size
):
def
fake_net
(
batch_size
,
num_features
,
feature_size
):
return
tf
.
convert_to_tensor
(
return
tf
.
convert_to_tensor
(
np
.
random
.
uniform
(
size
=
(
batch_size
,
num_features
,
feature_size
)),
value
=
np
.
random
.
uniform
(
size
=
(
batch_size
,
num_features
,
feature_size
)),
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
def
fake_labels
(
batch_size
,
seq_length
,
num_char_classes
):
def
fake_labels
(
batch_size
,
seq_length
,
num_char_classes
):
labels_np
=
tf
.
convert_to_tensor
(
labels_np
=
tf
.
convert_to_tensor
(
np
.
random
.
randint
(
value
=
np
.
random
.
randint
(
low
=
0
,
high
=
num_char_classes
,
size
=
(
batch_size
,
seq_length
)))
low
=
0
,
high
=
num_char_classes
,
size
=
(
batch_size
,
seq_length
)))
return
slim
.
one_hot_encoding
(
labels_np
,
num_classes
=
num_char_classes
)
return
slim
.
one_hot_encoding
(
labels_np
,
num_classes
=
num_char_classes
)
...
...
research/attention_ocr/python/train.py
View file @
0cceabfc
...
@@ -96,16 +96,16 @@ def get_training_hparams():
...
@@ -96,16 +96,16 @@ def get_training_hparams():
def
create_optimizer
(
hparams
):
def
create_optimizer
(
hparams
):
"""Creates optimized based on the specified flags."""
"""Creates optimized based on the specified flags."""
if
hparams
.
optimizer
==
'momentum'
:
if
hparams
.
optimizer
==
'momentum'
:
optimizer
=
tf
.
train
.
MomentumOptimizer
(
optimizer
=
tf
.
compat
.
v1
.
train
.
MomentumOptimizer
(
hparams
.
learning_rate
,
momentum
=
hparams
.
momentum
)
hparams
.
learning_rate
,
momentum
=
hparams
.
momentum
)
elif
hparams
.
optimizer
==
'adam'
:
elif
hparams
.
optimizer
==
'adam'
:
optimizer
=
tf
.
train
.
AdamOptimizer
(
hparams
.
learning_rate
)
optimizer
=
tf
.
compat
.
v1
.
train
.
AdamOptimizer
(
hparams
.
learning_rate
)
elif
hparams
.
optimizer
==
'adadelta'
:
elif
hparams
.
optimizer
==
'adadelta'
:
optimizer
=
tf
.
train
.
AdadeltaOptimizer
(
hparams
.
learning_rate
)
optimizer
=
tf
.
compat
.
v1
.
train
.
AdadeltaOptimizer
(
hparams
.
learning_rate
)
elif
hparams
.
optimizer
==
'adagrad'
:
elif
hparams
.
optimizer
==
'adagrad'
:
optimizer
=
tf
.
train
.
AdagradOptimizer
(
hparams
.
learning_rate
)
optimizer
=
tf
.
compat
.
v1
.
train
.
AdagradOptimizer
(
hparams
.
learning_rate
)
elif
hparams
.
optimizer
==
'rmsprop'
:
elif
hparams
.
optimizer
==
'rmsprop'
:
optimizer
=
tf
.
train
.
RMSPropOptimizer
(
optimizer
=
tf
.
compat
.
v1
.
train
.
RMSPropOptimizer
(
hparams
.
learning_rate
,
momentum
=
hparams
.
momentum
)
hparams
.
learning_rate
,
momentum
=
hparams
.
momentum
)
return
optimizer
return
optimizer
...
@@ -154,14 +154,14 @@ def train(loss, init_fn, hparams):
...
@@ -154,14 +154,14 @@ def train(loss, init_fn, hparams):
def
prepare_training_dir
():
def
prepare_training_dir
():
if
not
tf
.
gfile
.
E
xists
(
FLAGS
.
train_log_dir
):
if
not
tf
.
io
.
gfile
.
e
xists
(
FLAGS
.
train_log_dir
):
logging
.
info
(
'Create a new training directory %s'
,
FLAGS
.
train_log_dir
)
logging
.
info
(
'Create a new training directory %s'
,
FLAGS
.
train_log_dir
)
tf
.
gfile
.
M
ake
D
irs
(
FLAGS
.
train_log_dir
)
tf
.
io
.
gfile
.
m
ake
d
irs
(
FLAGS
.
train_log_dir
)
else
:
else
:
if
FLAGS
.
reset_train_dir
:
if
FLAGS
.
reset_train_dir
:
logging
.
info
(
'Reset the training directory %s'
,
FLAGS
.
train_log_dir
)
logging
.
info
(
'Reset the training directory %s'
,
FLAGS
.
train_log_dir
)
tf
.
gfile
.
DeleteRecursively
(
FLAGS
.
train_log_dir
)
tf
.
io
.
gfile
.
rmtree
(
FLAGS
.
train_log_dir
)
tf
.
gfile
.
M
ake
D
irs
(
FLAGS
.
train_log_dir
)
tf
.
io
.
gfile
.
m
ake
d
irs
(
FLAGS
.
train_log_dir
)
else
:
else
:
logging
.
info
(
'Use already existing training directory %s'
,
logging
.
info
(
'Use already existing training directory %s'
,
FLAGS
.
train_log_dir
)
FLAGS
.
train_log_dir
)
...
@@ -169,7 +169,7 @@ def prepare_training_dir():
...
@@ -169,7 +169,7 @@ def prepare_training_dir():
def
calculate_graph_metrics
():
def
calculate_graph_metrics
():
param_stats
=
model_analyzer
.
print_model_analysis
(
param_stats
=
model_analyzer
.
print_model_analysis
(
tf
.
get_default_graph
(),
tf
.
compat
.
v1
.
get_default_graph
(),
tfprof_options
=
model_analyzer
.
TRAINABLE_VARS_PARAMS_STAT_OPTIONS
)
tfprof_options
=
model_analyzer
.
TRAINABLE_VARS_PARAMS_STAT_OPTIONS
)
return
param_stats
.
total_parameters
return
param_stats
.
total_parameters
...
@@ -186,7 +186,7 @@ def main(_):
...
@@ -186,7 +186,7 @@ def main(_):
# If ps_tasks is zero, the local device is used. When using multiple
# If ps_tasks is zero, the local device is used. When using multiple
# (non-local) replicas, the ReplicaDeviceSetter distributes the variables
# (non-local) replicas, the ReplicaDeviceSetter distributes the variables
# across the different devices.
# across the different devices.
device_setter
=
tf
.
train
.
replica_device_setter
(
device_setter
=
tf
.
compat
.
v1
.
train
.
replica_device_setter
(
FLAGS
.
ps_tasks
,
merge_devices
=
True
)
FLAGS
.
ps_tasks
,
merge_devices
=
True
)
with
tf
.
device
(
device_setter
):
with
tf
.
device
(
device_setter
):
data
=
data_provider
.
get_data
(
data
=
data_provider
.
get_data
(
...
...
research/attention_ocr/python/utils.py
View file @
0cceabfc
...
@@ -37,16 +37,16 @@ def logits_to_log_prob(logits):
...
@@ -37,16 +37,16 @@ def logits_to_log_prob(logits):
probabilities.
probabilities.
"""
"""
with
tf
.
variable_scope
(
'log_probabilities'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'log_probabilities'
):
reduction_indices
=
len
(
logits
.
shape
.
as_list
())
-
1
reduction_indices
=
len
(
logits
.
shape
.
as_list
())
-
1
max_logits
=
tf
.
reduce_max
(
max_logits
=
tf
.
reduce_max
(
logits
,
reduction_indice
s
=
reduction_indices
,
keep
_
dims
=
True
)
input_tensor
=
logits
,
axi
s
=
reduction_indices
,
keepdims
=
True
)
safe_logits
=
tf
.
subtract
(
logits
,
max_logits
)
safe_logits
=
tf
.
subtract
(
logits
,
max_logits
)
sum_exp
=
tf
.
reduce_sum
(
sum_exp
=
tf
.
reduce_sum
(
tf
.
exp
(
safe_logits
),
input_tensor
=
tf
.
exp
(
safe_logits
),
reduction_indice
s
=
reduction_indices
,
axi
s
=
reduction_indices
,
keep
_
dims
=
True
)
keepdims
=
True
)
log_probs
=
tf
.
subtract
(
safe_logits
,
tf
.
log
(
sum_exp
))
log_probs
=
tf
.
subtract
(
safe_logits
,
tf
.
math
.
log
(
sum_exp
))
return
log_probs
return
log_probs
...
@@ -78,3 +78,20 @@ def variables_to_restore(scope=None, strip_scope=False):
...
@@ -78,3 +78,20 @@ def variables_to_restore(scope=None, strip_scope=False):
return
variable_map
return
variable_map
else
:
else
:
return
{
v
.
op
.
name
:
v
for
v
in
slim
.
get_variables_to_restore
()}
return
{
v
.
op
.
name
:
v
for
v
in
slim
.
get_variables_to_restore
()}
def
ConvertAllInputsToTensors
(
func
):
"""A decorator to convert all function's inputs into tensors.
Args:
func: a function to decorate.
Returns:
A decorated function.
"""
def
FuncWrapper
(
*
args
):
tensors
=
[
tf
.
convert_to_tensor
(
value
=
a
)
for
a
in
args
]
return
func
(
*
tensors
)
return
FuncWrapper
research/autoencoder/AdditiveGaussianNoiseAutoencoderRunner.py
deleted
100644 → 0
View file @
17821c0d
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
sklearn.preprocessing
as
prep
import
tensorflow
as
tf
from
tensorflow.examples.tutorials.mnist
import
input_data
from
autoencoder_models.DenoisingAutoencoder
import
AdditiveGaussianNoiseAutoencoder
mnist
=
input_data
.
read_data_sets
(
'MNIST_data'
,
one_hot
=
True
)
def
standard_scale
(
X_train
,
X_test
):
preprocessor
=
prep
.
StandardScaler
().
fit
(
X_train
)
X_train
=
preprocessor
.
transform
(
X_train
)
X_test
=
preprocessor
.
transform
(
X_test
)
return
X_train
,
X_test
def
get_random_block_from_data
(
data
,
batch_size
):
start_index
=
np
.
random
.
randint
(
0
,
len
(
data
)
-
batch_size
)
return
data
[
start_index
:(
start_index
+
batch_size
)]
X_train
,
X_test
=
standard_scale
(
mnist
.
train
.
images
,
mnist
.
test
.
images
)
n_samples
=
int
(
mnist
.
train
.
num_examples
)
training_epochs
=
20
batch_size
=
128
display_step
=
1
autoencoder
=
AdditiveGaussianNoiseAutoencoder
(
n_input
=
784
,
n_hidden
=
200
,
transfer_function
=
tf
.
nn
.
softplus
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
0.001
),
scale
=
0.01
)
for
epoch
in
range
(
training_epochs
):
avg_cost
=
0.
total_batch
=
int
(
n_samples
/
batch_size
)
# Loop over all batches
for
i
in
range
(
total_batch
):
batch_xs
=
get_random_block_from_data
(
X_train
,
batch_size
)
# Fit training using batch data
cost
=
autoencoder
.
partial_fit
(
batch_xs
)
# Compute average loss
avg_cost
+=
cost
/
n_samples
*
batch_size
# Display logs per epoch step
if
epoch
%
display_step
==
0
:
print
(
"Epoch:"
,
'%d,'
%
(
epoch
+
1
),
"Cost:"
,
"{:.9f}"
.
format
(
avg_cost
))
print
(
"Total cost: "
+
str
(
autoencoder
.
calc_total_cost
(
X_test
)))
research/autoencoder/AutoencoderRunner.py
deleted
100644 → 0
View file @
17821c0d
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
sklearn.preprocessing
as
prep
import
tensorflow
as
tf
from
tensorflow.examples.tutorials.mnist
import
input_data
from
autoencoder_models.Autoencoder
import
Autoencoder
mnist
=
input_data
.
read_data_sets
(
'MNIST_data'
,
one_hot
=
True
)
def
standard_scale
(
X_train
,
X_test
):
preprocessor
=
prep
.
StandardScaler
().
fit
(
X_train
)
X_train
=
preprocessor
.
transform
(
X_train
)
X_test
=
preprocessor
.
transform
(
X_test
)
return
X_train
,
X_test
def
get_random_block_from_data
(
data
,
batch_size
):
start_index
=
np
.
random
.
randint
(
0
,
len
(
data
)
-
batch_size
)
return
data
[
start_index
:(
start_index
+
batch_size
)]
X_train
,
X_test
=
standard_scale
(
mnist
.
train
.
images
,
mnist
.
test
.
images
)
n_samples
=
int
(
mnist
.
train
.
num_examples
)
training_epochs
=
20
batch_size
=
128
display_step
=
1
autoencoder
=
Autoencoder
(
n_layers
=
[
784
,
200
],
transfer_function
=
tf
.
nn
.
softplus
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
0.001
))
for
epoch
in
range
(
training_epochs
):
avg_cost
=
0.
total_batch
=
int
(
n_samples
/
batch_size
)
# Loop over all batches
for
i
in
range
(
total_batch
):
batch_xs
=
get_random_block_from_data
(
X_train
,
batch_size
)
# Fit training using batch data
cost
=
autoencoder
.
partial_fit
(
batch_xs
)
# Compute average loss
avg_cost
+=
cost
/
n_samples
*
batch_size
# Display logs per epoch step
if
epoch
%
display_step
==
0
:
print
(
"Epoch:"
,
'%d,'
%
(
epoch
+
1
),
"Cost:"
,
"{:.9f}"
.
format
(
avg_cost
))
print
(
"Total cost: "
+
str
(
autoencoder
.
calc_total_cost
(
X_test
)))
research/autoencoder/MaskingNoiseAutoencoderRunner.py
deleted
100644 → 0
View file @
17821c0d
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
sklearn.preprocessing
as
prep
import
tensorflow
as
tf
from
tensorflow.examples.tutorials.mnist
import
input_data
from
autoencoder_models.DenoisingAutoencoder
import
MaskingNoiseAutoencoder
mnist
=
input_data
.
read_data_sets
(
'MNIST_data'
,
one_hot
=
True
)
def
standard_scale
(
X_train
,
X_test
):
preprocessor
=
prep
.
StandardScaler
().
fit
(
X_train
)
X_train
=
preprocessor
.
transform
(
X_train
)
X_test
=
preprocessor
.
transform
(
X_test
)
return
X_train
,
X_test
def
get_random_block_from_data
(
data
,
batch_size
):
start_index
=
np
.
random
.
randint
(
0
,
len
(
data
)
-
batch_size
)
return
data
[
start_index
:(
start_index
+
batch_size
)]
X_train
,
X_test
=
standard_scale
(
mnist
.
train
.
images
,
mnist
.
test
.
images
)
n_samples
=
int
(
mnist
.
train
.
num_examples
)
training_epochs
=
100
batch_size
=
128
display_step
=
1
autoencoder
=
MaskingNoiseAutoencoder
(
n_input
=
784
,
n_hidden
=
200
,
transfer_function
=
tf
.
nn
.
softplus
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
0.001
),
dropout_probability
=
0.95
)
for
epoch
in
range
(
training_epochs
):
avg_cost
=
0.
total_batch
=
int
(
n_samples
/
batch_size
)
for
i
in
range
(
total_batch
):
batch_xs
=
get_random_block_from_data
(
X_train
,
batch_size
)
cost
=
autoencoder
.
partial_fit
(
batch_xs
)
avg_cost
+=
cost
/
n_samples
*
batch_size
if
epoch
%
display_step
==
0
:
print
(
"Epoch:"
,
'%d,'
%
(
epoch
+
1
),
"Cost:"
,
"{:.9f}"
.
format
(
avg_cost
))
print
(
"Total cost: "
+
str
(
autoencoder
.
calc_total_cost
(
X_test
)))
Prev
1
…
10
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment