ModelZoo / ResNet50_tensorflow / Commits

Commit afd5579f, authored Jul 22, 2020 by Kaushik Shivakumar

    Merge remote-tracking branch 'upstream/master' into context_tf2

Parents: dcd96e02, 567bd18d

Changes: 89 files in total; this page shows 20 changed files with 250 additions and 211 deletions (+250 −211).
Changed files on this page:

* orbit/__init__.py (+2 −0)
* orbit/controller.py (+3 −8)
* orbit/controller_test.py (+2 −5)
* orbit/runner.py (+3 −10)
* orbit/standard_runner.py (+3 −10)
* orbit/standard_runner_test.py (+1 −0)
* orbit/utils.py (+5 −14)
* research/attention_ocr/README.md (+58 −5)
* research/attention_ocr/python/data_provider.py (+8 −7)
* research/attention_ocr/python/datasets/fsns.py (+11 −11)
* research/attention_ocr/python/datasets/fsns_test.py (+1 −1)
* research/attention_ocr/python/datasets/testdata/fsns/download_data.py (+3 −2)
* research/attention_ocr/python/demo_inference.py (+12 −11)
* research/attention_ocr/python/demo_inference_test.py (+40 −39)
* research/attention_ocr/python/eval.py (+3 −3)
* research/attention_ocr/python/inception_preprocessing.py (+16 −16)
* research/attention_ocr/python/metrics.py (+20 −18)
* research/attention_ocr/python/metrics_test.py (+7 −7)
* research/attention_ocr/python/model.py (+36 −30)
* research/attention_ocr/python/model_export.py (+16 −14)
orbit/__init__.py — view file @ afd5579f

```diff
+# Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -12,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Orbit package definition."""
 
+from orbit import utils
 from orbit.controller import Controller
 ...
```
orbit/controller.py — view file @ afd5579f

```diff
+# Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -14,14 +15,8 @@
 # ==============================================================================
 """A light weight utilities to train TF2 models."""
 
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import time
 from typing import Callable, Optional, Text, Union
 
 from absl import logging
 from orbit import runner
 from orbit import utils
 ...
@@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
         interval_name, interval, steps_per_loop))
 
-class Controller(object):
+class Controller:
   """Class that facilitates training and evaluation of models."""
 
   def __init__(
 ...
@@ -396,7 +391,7 @@ class Controller(object):
     return False
 
-class StepTimer(object):
+class StepTimer:
   """Utility class for measuring steps/second."""
 
   def __init__(self, step):
 ...
```
orbit/controller_test.py — view file @ afd5579f

```diff
+# Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -14,10 +15,6 @@
 # ==============================================================================
 """Tests for orbit.controller."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
 import os
 
 from absl import logging
 from absl.testing import parameterized
 ...
@@ -203,7 +200,7 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
 class ControllerTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
-    super(ControllerTest, self).setUp()
+    super().setUp()
     self.model_dir = self.get_temp_dir()
 
   def test_no_checkpoint(self):
 ...
```
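For readers skimming the hunk above: the two `setUp` spellings are equivalent in Python 3, where the zero-argument `super()` is filled in by the compiler from the enclosing class and instance. A minimal self-contained illustration (class names here are invented, not from the test):

```python
class Base:
  def setUp(self):
    return 'base setUp ran'

class LongForm(Base):
  def setUp(self):
    # Python 2 compatible spelling: class and instance passed explicitly.
    return super(LongForm, self).setUp()

class ShortForm(Base):
  def setUp(self):
    # Python 3 zero-argument form: identical behavior, less repetition.
    return super().setUp()

assert LongForm().setUp() == ShortForm().setUp() == 'base setUp ran'
```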
orbit/runner.py — view file @ afd5579f

```diff
+# Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -14,19 +15,12 @@
 # ==============================================================================
 """An abstraction that users can easily handle their custom training loops."""
 
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import abc
 from typing import Dict, Optional, Text
 
-import six
 import tensorflow as tf
 
-@six.add_metaclass(abc.ABCMeta)
-class AbstractTrainer(tf.Module):
+class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
   """An abstract class defining the APIs required for training."""
 
   @abc.abstractmethod
 ...
@@ -56,8 +50,7 @@ class AbstractTrainer(tf.Module):
     pass
 
-@six.add_metaclass(abc.ABCMeta)
-class AbstractEvaluator(tf.Module):
+class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
   """An abstract class defining the APIs required for evaluation."""
 
   @abc.abstractmethod
 ...
```
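The `@six.add_metaclass(abc.ABCMeta)` decorator existed only for Python 2 compatibility; with the `__future__` imports gone, the native `metaclass=` keyword expresses the same thing. A minimal sketch of the equivalence (class names are illustrative, not Orbit's):

```python
import abc

# Python 3 native syntax: the metaclass keyword replaces six.add_metaclass.
class Base(metaclass=abc.ABCMeta):

  @abc.abstractmethod
  def step(self):
    """Subclasses must implement a single training step."""

class Impl(Base):

  def step(self):
    return 42

# Base() would raise TypeError: abstract methods block instantiation.
print(Impl().step())  # 42
```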
orbit/standard_runner.py — view file @ afd5579f

```diff
+# Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -14,21 +15,14 @@
 # ==============================================================================
 """An abstraction that users can easily handle their custom training loops."""
 
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import abc
 from typing import Any, Dict, Optional, Text
 
 from orbit import runner
 from orbit import utils
 
-import six
 import tensorflow as tf
 
-@six.add_metaclass(abc.ABCMeta)
-class StandardTrainer(runner.AbstractTrainer):
+class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractTrainer APIs."""
 
   def __init__(self,
 ...
@@ -145,8 +139,7 @@ class StandardTrainer(runner.AbstractTrainer):
     self._train_iter = None
 
-@six.add_metaclass(abc.ABCMeta)
-class StandardEvaluator(runner.AbstractEvaluator):
+class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractEvaluator APIs."""
 
   def __init__(self, eval_dataset, use_tf_function=True):
 ...
```
orbit/standard_runner_test.py — view file @ afd5579f

```diff
+# Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
```
orbit/utils.py — view file @ afd5579f

```diff
+# Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -14,18 +15,12 @@
 # ==============================================================================
 """Some layered modules/functions to help users writing custom training loop."""
 
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import abc
 import contextlib
 import functools
 import inspect
 
 import numpy as np
-import six
 import tensorflow as tf
 ...
@@ -132,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
   # names, pass `ctx` as the value of `input_context` when calling
   # `dataset_or_fn`. Otherwise `ctx` will not be used when calling
   # `dataset_or_fn`.
-  if six.PY3:
-    argspec = inspect.getfullargspec(dataset_or_fn)
-  else:
-    argspec = inspect.getargspec(dataset_or_fn)  # pylint: disable=deprecated-method
+  argspec = inspect.getfullargspec(dataset_or_fn)
   args_names = argspec.args
 
   if "input_context" in args_names:
 ...
@@ -146,7 +138,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
   return strategy.experimental_distribute_datasets_from_function(dataset_fn)
 
-class SummaryManager(object):
+class SummaryManager:
   """A class manages writing summaries."""
 
   def __init__(self, summary_dir, summary_fn, global_step=None):
 ...
@@ -201,8 +193,7 @@ class SummaryManager(object):
     self._summary_fn(name, tensor, step=self._global_step)
 
-@six.add_metaclass(abc.ABCMeta)
-class Trigger(object):
+class Trigger(metaclass=abc.ABCMeta):
   """An abstract class representing a "trigger" for some event."""
 
   @abc.abstractmethod
 ...
@@ -263,7 +254,7 @@ class IntervalTrigger(Trigger):
     self._last_trigger_value = 0
 
-class EpochHelper(object):
+class EpochHelper:
   """A Helper class to handle epochs in Customized Training Loop."""
 
   def __init__(self, epoch_steps, global_step):
 ...
```
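`inspect.getargspec` was deprecated throughout Python 3 (and removed in 3.11), and `inspect.getfullargspec` is Python-3 only, which is why the `six.PY3` branch collapses to a single call. The surrounding logic in `make_distributed_dataset` forwards an input context only to callables that declare an `input_context` parameter; a standalone sketch of that dispatch (the helper names below are hypothetical, not from Orbit):

```python
import inspect

def call_with_optional_context(dataset_or_fn, ctx, *args, **kwargs):
  """Pass `ctx` as `input_context` only if the callable accepts it (sketch)."""
  argspec = inspect.getfullargspec(dataset_or_fn)
  if "input_context" in argspec.args:
    kwargs["input_context"] = ctx
  return dataset_or_fn(*args, **kwargs)

def make_range_dataset(start, stop, input_context=None):
  # A real dataset function would shard by input_context; here we only show
  # that the parameter is detected and forwarded.
  return list(range(start, stop)), input_context

print(call_with_optional_context(make_range_dataset, "CTX", 0, 3))
# ([0, 1, 2], 'CTX')
```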
research/attention_ocr/README.md — view file @ afd5579f

````diff
 # Attention-based Extraction of Structured Information from Street View Imagery
 
 [![…](…)](https://paperswithcode.com/sota/optical-character-recognition-on-fsns-test?p=attention-based-extraction-of-structured)
 [![…](…)](https://arxiv.org/abs/1704.03549)
 ...
@@ -7,14 +7,20 @@
 *A TensorFlow model for real-world image text extraction problems.*
 
 This folder contains the code needed to train a new Attention OCR model on the
-[FSNS dataset][FSNS] dataset to transcribe street names in France. You can
-also use it to train it on your own data.
+[FSNS dataset][FSNS] to transcribe street names in France. You can also train
+the code on your own data.
 
 More details can be found in our paper:
 
 ["Attention-based Extraction of Structured Information from Street View
 Imagery"](https://arxiv.org/abs/1704.03549)
 
+## Description
+
+* Paper presents a model based on ConvNets, RNN's and a novel attention
+  mechanism. Achieves **84.2%** on FSNS beating the previous benchmark
+  (**72.46%**). Also studies the speed/accuracy tradeoff that results from
+  using CNN feature extractors of different depths.
 
 ## Contacts
 Authors
 ...
@@ -22,7 +28,18 @@ Authors
 * Zbigniew Wojna (zbigniewwojna@gmail.com)
 * Alexander Gorban (gorban@google.com)
 
-Maintainer: Xavier Gibert [@xavigibert](https://github.com/xavigibert)
+Maintainer
+* Xavier Gibert ([@xavigibert](https://github.com/xavigibert))
+
+## Table of Contents
+* [Requirements](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#requirements)
+* [Dataset](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#dataset)
+* [How to use this code](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-this-code)
+* [Using your own image data](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#using-your-own-image-data)
+* [How to use a pre-trained model](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-a-pre-trained-model)
+* [Disclaimer](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#disclaimer)
 
 ## Requirements
 ...
@@ -49,6 +66,42 @@ cd ..
 [TF]: https://www.tensorflow.org/install/
 [FSNS]: https://github.com/tensorflow/models/tree/master/research/street
 
+## Dataset
+
+The French Street Name Signs (FSNS) dataset is split into subsets,
+each of which is composed of multiple files. Note that these datasets
+are very large. The approximate sizes are:
+
+* Train: 512 files of 300MB each.
+* Validation: 64 files of 40MB each.
+* Test: 64 files of 50MB each.
+* The datasets download includes a directory `testdata` that contains
+  some small datasets that are big enough to test that models can
+  actually learn something.
+* Total: around 158GB
+
+The download paths are in the following list:
+
+```
+https://download.tensorflow.org/data/fsns-20160927/charset_size=134.txt
+https://download.tensorflow.org/data/fsns-20160927/test/test-00000-of-00064
+...
+https://download.tensorflow.org/data/fsns-20160927/test/test-00063-of-00064
+https://download.tensorflow.org/data/fsns-20160927/testdata/arial-32-00000-of-00001
+https://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
+https://download.tensorflow.org/data/fsns-20160927/testdata/mnist-sample-00000-of-00001
+https://download.tensorflow.org/data/fsns-20160927/testdata/numbers-16-00000-of-00001
+https://download.tensorflow.org/data/fsns-20160927/train/train-00000-of-00512
+...
+https://download.tensorflow.org/data/fsns-20160927/train/train-00511-of-00512
+https://download.tensorflow.org/data/fsns-20160927/validation/validation-00000-of-00064
+...
+https://download.tensorflow.org/data/fsns-20160927/validation/validation-00063-of-00064
+```
+
+All URLs are stored in the
+[research/street](https://github.com/tensorflow/models/tree/master/research/street)
+repository in the text file `python/fsns_urls.txt`.
 
 ## How to use this code
 
 To run all unit tests:
 ...
@@ -80,7 +133,7 @@ tar xf attention_ocr_2017_08_09.tar.gz
 python train.py --checkpoint=model.ckpt-399731
 ```
 
-## How to use your own image data to train the model
+## Using your own image data
 
 You need to define a new dataset. There are two options:
 ...
````
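The shard URLs added above follow a strict `name-%05d-of-%05d` pattern, so the full list can be generated rather than copied out of `fsns_urls.txt`. A minimal sketch using only the base path and shard counts stated in the README (`out_dir` is an arbitrary local directory):

```python
import os
import urllib.request

BASE = 'https://download.tensorflow.org/data/fsns-20160927'
SPLITS = {'train': 512, 'validation': 64, 'test': 64}

def fsns_urls():
  """Yield every FSNS download URL following the shard naming pattern."""
  yield '%s/charset_size=134.txt' % BASE
  for split, num_shards in SPLITS.items():
    for i in range(num_shards):
      yield '%s/%s/%s-%05d-of-%05d' % (BASE, split, split, i, num_shards)

def download_all(out_dir='fsns'):  # ~158GB total; see the sizes above.
  os.makedirs(out_dir, exist_ok=True)
  for url in fsns_urls():
    dst = os.path.join(out_dir, url[len(BASE) + 1:].replace('/', '_'))
    urllib.request.urlretrieve(url, dst)
```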
research/attention_ocr/python/data_provider.py — view file @ afd5579f

```diff
@@ -56,14 +56,14 @@ def augment_image(image):
   Returns:
     Distorted Tensor image of the same shape.
   """
-  with tf.variable_scope('AugmentImage'):
+  with tf.compat.v1.variable_scope('AugmentImage'):
     height = image.get_shape().dims[0].value
     width = image.get_shape().dims[1].value
 
     # Random crop cut from the street sign image, resized to the same size.
     # Assures that the crop is covers at least 0.8 area of the input image.
     bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
-        tf.shape(image),
+        image_size=tf.shape(input=image),
         bounding_boxes=tf.zeros([0, 0, 4]),
         min_object_covered=0.8,
         aspect_ratio_range=[0.8, 1.2],
 ...
@@ -74,7 +74,7 @@ def augment_image(image):
     # Randomly chooses one of the 4 interpolation methods
     distorted_image = inception_preprocessing.apply_with_random_selector(
         distorted_image,
-        lambda x, method: tf.image.resize_images(x, [height, width], method),
+        lambda x, method: tf.image.resize(x, [height, width], method),
         num_cases=4)
     distorted_image.set_shape([height, width, 3])
 ...
@@ -99,9 +99,10 @@ def central_crop(image, crop_size):
   Returns:
     A tensor of shape [crop_height, crop_width, channels].
   """
-  with tf.variable_scope('CentralCrop'):
+  with tf.compat.v1.variable_scope('CentralCrop'):
     target_width, target_height = crop_size
-    image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
+    image_height, image_width = tf.shape(input=image)[0], tf.shape(
+        input=image)[1]
     assert_op1 = tf.Assert(
         tf.greater_equal(image_height, target_height),
         ['image_height < target_height', image_height, target_height])
 ...
@@ -129,7 +130,7 @@ def preprocess_image(image, augment=False, central_crop_size=None,
     A float32 tensor of shape [H x W x 3] with RGB values in the required
     range.
   """
-  with tf.variable_scope('PreprocessImage'):
+  with tf.compat.v1.variable_scope('PreprocessImage'):
     image = tf.image.convert_image_dtype(image, dtype=tf.float32)
     if augment or central_crop_size:
       if num_towers == 1:
 ...
@@ -182,7 +183,7 @@ def get_data(dataset,
       image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
   label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)
 
-  images, images_orig, labels, labels_one_hot = (tf.train.shuffle_batch(
+  images, images_orig, labels, labels_one_hot = (tf.compat.v1.train.shuffle_batch(
       [image, image_orig, label, label_one_hot],
       batch_size=batch_size,
       num_threads=shuffle_config.num_batching_threads,
 ...
```
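Every hunk in this file follows the same mechanical recipe: graph-mode symbols dropped from the TF2 root namespace move under `tf.compat.v1`, and positional tensor arguments gain their new keyword names (`input=`, `image_size=`). A quick self-check of that premise, assuming a TF 2.x install (this is illustrative, not code from the repository):

```python
import tensorflow as tf  # assumes TF 2.x

# TF1 graph-mode symbols survive only under the compat.v1 shim in TF 2.x:
assert not hasattr(tf, 'variable_scope')
assert hasattr(tf.compat.v1, 'variable_scope')

# Renamed ops keep working under their TF2 names:
image = tf.zeros([4, 4, 3])
print(tf.image.resize(image, [8, 8]).shape)  # (8, 8, 3)
```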
research/attention_ocr/python/datasets/fsns.py — view file @ afd5579f

```diff
@@ -72,7 +72,7 @@ def read_charset(filename, null_character=u'\u2591'):
   """
   pattern = re.compile(r'(\d+)\t(.+)')
   charset = {}
-  with tf.gfile.GFile(filename) as f:
+  with tf.io.gfile.GFile(filename) as f:
     for i, line in enumerate(f):
       m = pattern.match(line)
       if m is None:
 ...
@@ -96,9 +96,9 @@ class _NumOfViewsHandler(slim.tfexample_decoder.ItemHandler):
     self._num_of_views = num_of_views
 
   def tensors_to_item(self, keys_to_tensors):
-    return tf.to_int64(
-        self._num_of_views * keys_to_tensors[self._original_width_key] /
-        keys_to_tensors[self._width_key])
+    return tf.cast(
+        self._num_of_views * keys_to_tensors[self._original_width_key] /
+        keys_to_tensors[self._width_key], dtype=tf.int64)
 
 def get_split(split_name, dataset_dir=None, config=None):
 ...
@@ -133,19 +133,19 @@ def get_split(split_name, dataset_dir=None, config=None):
   zero = tf.zeros([1], dtype=tf.int64)
   keys_to_features = {
       'image/encoded':
-          tf.FixedLenFeature((), tf.string, default_value=''),
+          tf.io.FixedLenFeature((), tf.string, default_value=''),
       'image/format':
-          tf.FixedLenFeature((), tf.string, default_value='png'),
+          tf.io.FixedLenFeature((), tf.string, default_value='png'),
       'image/width':
-          tf.FixedLenFeature([1], tf.int64, default_value=zero),
+          tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
       'image/orig_width':
-          tf.FixedLenFeature([1], tf.int64, default_value=zero),
+          tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
       'image/class':
-          tf.FixedLenFeature([config['max_sequence_length']], tf.int64),
+          tf.io.FixedLenFeature([config['max_sequence_length']], tf.int64),
       'image/unpadded_class':
-          tf.VarLenFeature(tf.int64),
+          tf.io.VarLenFeature(tf.int64),
       'image/text':
-          tf.FixedLenFeature([1], tf.string, default_value=''),
+          tf.io.FixedLenFeature([1], tf.string, default_value=''),
   }
   items_to_handlers = {
       'image':
 ...
@@ -171,7 +171,7 @@ def get_split(split_name, dataset_dir=None, config=None):
       config['splits'][split_name]['pattern'])
   return slim.dataset.Dataset(
       data_sources=file_pattern,
-      reader=tf.TFRecordReader,
+      reader=tf.compat.v1.TFRecordReader,
       decoder=decoder,
       num_samples=config['splits'][split_name]['size'],
       items_to_descriptions=config['items_to_descriptions'],
 ...
```
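The `tf.to_int64` → `tf.cast(..., dtype=tf.int64)` rewrite is behavior-preserving: the `tf.to_*` helpers were thin, deprecated wrappers around `tf.cast`. A small end-to-end check of the migrated `tf.io` feature specs, with two keys borrowed from the hunk above and a synthetic serialized example:

```python
import tensorflow as tf

zero = tf.zeros([1], dtype=tf.int64)
keys_to_features = {
    'image/width': tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
    'image/text': tf.io.FixedLenFeature([1], tf.string, default_value=''),
}

# Build a synthetic tf.train.Example to parse (values invented).
example = tf.train.Example(features=tf.train.Features(feature={
    'image/width': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[600])),
    'image/text': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'Rue de Provence'])),
}))

parsed = tf.io.parse_single_example(example.SerializeToString(),
                                    keys_to_features)
print(parsed['image/width'].numpy())  # [600]
# tf.cast(..., dtype=tf.int64) is the drop-in replacement for tf.to_int64.
print(tf.cast(parsed['image/width'], dtype=tf.int64).dtype)  # <dtype: 'int64'>
```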
research/attention_ocr/python/datasets/fsns_test.py — view file @ afd5579f

```diff
@@ -91,7 +91,7 @@ class FsnsTest(tf.test.TestCase):
     image_tf, label_tf = provider.get(['image', 'label'])
 
     with self.test_session() as sess:
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())
       with slim.queues.QueueRunners(sess):
         image_np, label_np = sess.run([image_tf, label_tf])
 ...
```
research/attention_ocr/python/datasets/testdata/fsns/download_data.py — view file @ afd5579f

```diff
@@ -10,7 +10,8 @@ KEEP_NUM_RECORDS = 5
 print('Downloading %s ...' % URL)
 urllib.request.urlretrieve(URL, DST_ORIG)
 
-print('Writing %d records from %s to %s ...' % (KEEP_NUM_RECORDS, DST_ORIG, DST))
+print('Writing %d records from %s to %s ...' %
+      (KEEP_NUM_RECORDS, DST_ORIG, DST))
 with tf.io.TFRecordWriter(DST) as writer:
-  for raw_record in itertools.islice(
-      tf.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
+  for raw_record in itertools.islice(
+      tf.compat.v1.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
     writer.write(raw_record)
```
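Equivalently, TF2 code would usually express the "keep only the first N records" truncation with `tf.data.TFRecordDataset.take` rather than `itertools.islice` over the compat iterator. A sketch under that assumption (file names here are placeholders):

```python
import tensorflow as tf

SRC, DST, KEEP_NUM_RECORDS = 'fsns_orig.tfrecord', 'fsns_small.tfrecord', 5

# take(N) stops after N serialized records, mirroring itertools.islice.
dataset = tf.data.TFRecordDataset(SRC).take(KEEP_NUM_RECORDS)
with tf.io.TFRecordWriter(DST) as writer:
  for raw_record in dataset:
    writer.write(raw_record.numpy())  # each element is a scalar string tensor
```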
research/attention_ocr/python/demo_inference.py — view file @ afd5579f

```diff
@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
   for i in range(batch_size):
     path = file_pattern % i
     print("Reading %s" % path)
-    pil_image = PIL.Image.open(tf.gfile.GFile(path, 'rb'))
+    pil_image = PIL.Image.open(tf.io.gfile.GFile(path, 'rb'))
     images_actual_data[i, ...] = np.asarray(pil_image)
   return images_actual_data
 ...
@@ -58,12 +58,13 @@ def create_model(batch_size, dataset_name):
   width, height = get_dataset_image_size(dataset_name)
   dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
   model = common_flags.create_model(
       num_char_classes=dataset.num_char_classes,
       seq_length=dataset.max_sequence_length,
       num_views=dataset.num_of_views,
       null_code=dataset.null_code,
       charset=dataset.charset)
-  raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
+  raw_images = tf.compat.v1.placeholder(
+      tf.uint8, shape=[batch_size, height, width, 3])
   images = tf.map_fn(data_provider.preprocess_image, raw_images,
                      dtype=tf.float32)
   endpoints = model.create_base(images, labels_one_hot=None)
 ...
@@ -76,9 +77,9 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
   images_data = load_images(image_path_pattern, batch_size,
                             dataset_name)
   session_creator = monitored_session.ChiefSessionCreator(
       checkpoint_filename_with_path=checkpoint)
-  with monitored_session.MonitoredSession(session_creator=session_creator) as sess:
+  with monitored_session.MonitoredSession(
+      session_creator=session_creator) as sess:
     predictions = sess.run(endpoints.predicted_text,
                            feed_dict={images_placeholder: images_data})
   return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]
 ...
@@ -87,10 +88,10 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
 def main(_):
   print("Predicted strings:")
   predictions = run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
                     FLAGS.image_path_pattern)
   for line in predictions:
     print(line)
 
 if __name__ == '__main__':
-  tf.app.run()
+  tf.compat.v1.app.run()
```
research/attention_ocr/python/demo_inference_test.py — view file @ afd5579f

```diff
@@ -14,12 +14,13 @@ class DemoInferenceTest(tf.test.TestCase):
     super(DemoInferenceTest, self).setUp()
     for suffix in ['.meta', '.index', '.data-00000-of-00001']:
       filename = _CHECKPOINT + suffix
-      self.assertTrue(tf.gfile.Exists(filename),
+      self.assertTrue(tf.io.gfile.exists(filename),
                       msg='Missing checkpoint file %s. '
                       'Please download and extract it from %s' %
                       (filename, _CHECKPOINT_URL))
     self._batch_size = 32
-    tf.flags.FLAGS.dataset_dir = os.path.join(os.path.dirname(__file__), 'datasets/testdata/fsns')
+    tf.flags.FLAGS.dataset_dir = os.path.join(
+        os.path.dirname(__file__), 'datasets/testdata/fsns')
 
   def test_moving_variables_properly_loaded_from_a_checkpoint(self):
     batch_size = 32
 ...
@@ -30,15 +31,15 @@ class DemoInferenceTest(tf.test.TestCase):
     images_data = demo_inference.load_images(image_path_pattern, batch_size,
                                              dataset_name)
     tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
-    moving_mean_tf = tf.get_default_graph().get_tensor_by_name(tensor_name + ':0')
-    reader = tf.train.NewCheckpointReader(_CHECKPOINT)
+    moving_mean_tf = tf.compat.v1.get_default_graph().get_tensor_by_name(
+        tensor_name + ':0')
+    reader = tf.compat.v1.train.NewCheckpointReader(_CHECKPOINT)
     moving_mean_expected = reader.get_tensor(tensor_name)
 
     session_creator = monitored_session.ChiefSessionCreator(
         checkpoint_filename_with_path=_CHECKPOINT)
-    with monitored_session.MonitoredSession(session_creator=session_creator) as sess:
+    with monitored_session.MonitoredSession(
+        session_creator=session_creator) as sess:
       moving_mean_np = sess.run(moving_mean_tf,
                                 feed_dict={images_placeholder: images_data})
 ...
@@ -50,38 +51,38 @@ class DemoInferenceTest(tf.test.TestCase):
                                      'fsns', image_path_pattern)
     self.assertEqual([u'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
                       'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
                       'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
                       'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
                       'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
                       'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
                       'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
                       'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
                       'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
                       'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░',  # GT='Rue Thérésa'
                       'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
                       'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
                       'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
                       'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
                       'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue de la Libération░░░░░░░░░░░░░░░░░',
                       'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
                       'Avenue de la Grand Mare░░░░░░░░░░░░░░',
                       'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
                       'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
                       'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
                       'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
                       'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
                       'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
                       'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'],
                      predictions)
 ...
```
research/attention_ocr/python/eval.py — view file @ afd5579f

```diff
@@ -45,8 +45,8 @@ flags.DEFINE_integer('number_of_steps', None,
 def main(_):
-  if not tf.gfile.Exists(FLAGS.eval_log_dir):
-    tf.gfile.MakeDirs(FLAGS.eval_log_dir)
+  if not tf.io.gfile.exists(FLAGS.eval_log_dir):
+    tf.io.gfile.makedirs(FLAGS.eval_log_dir)
   dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
   model = common_flags.create_model(dataset.num_char_classes,
 ...
@@ -62,7 +62,7 @@ def main(_):
   eval_ops = model.create_summaries(
       data, endpoints, dataset.charset, is_training=False)
   slim.get_or_create_global_step()
-  session_config = tf.ConfigProto(device_count={"GPU": 0})
+  session_config = tf.compat.v1.ConfigProto(device_count={"GPU": 0})
   slim.evaluation.evaluation_loop(
       master=FLAGS.master,
       checkpoint_dir=FLAGS.train_log_dir,
 ...
```
research/attention_ocr/python/inception_preprocessing.py — view file @ afd5579f

```diff
@@ -38,7 +38,7 @@ def apply_with_random_selector(x, func, num_cases):
     The result of func(x, sel), where func receives the value of the
     selector as a python integer, but sel is sampled dynamically.
   """
-  sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
+  sel = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
   # Pass the real x only to one of the func calls.
   return control_flow_ops.merge([
       func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
 ...
@@ -64,7 +64,7 @@ def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
   Raises:
     ValueError: if color_ordering not in [0, 3]
   """
-  with tf.name_scope(scope, 'distort_color', [image]):
+  with tf.compat.v1.name_scope(scope, 'distort_color', [image]):
     if fast_mode:
       if color_ordering == 0:
         image = tf.image.random_brightness(image, max_delta=32. / 255.)
 ...
@@ -131,7 +131,7 @@ def distorted_bounding_box_crop(image,
   Returns:
     A tuple, a 3-D Tensor cropped_image and the distorted bbox
   """
-  with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
+  with tf.compat.v1.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
     # Each bounding box has shape [1, num_boxes, box coords] and
     # the coordinates are ordered [ymin, xmin, ymax, xmax].
 ...
@@ -143,7 +143,7 @@ def distorted_bounding_box_crop(image,
     # bounding box. If no box is supplied, then we assume the bounding box is
     # the entire image.
     sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
-        tf.shape(image),
+        image_size=tf.shape(input=image),
         bounding_boxes=bbox,
         min_object_covered=min_object_covered,
         aspect_ratio_range=aspect_ratio_range,
 ...
@@ -188,7 +188,7 @@ def preprocess_for_train(image,
   Returns:
     3-D float Tensor of distorted image used for training with range [-1, 1].
   """
-  with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
+  with tf.compat.v1.name_scope(scope, 'distort_image', [image, height, width, bbox]):
     if bbox is None:
       bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
                          dtype=tf.float32,
                          shape=[1, 1, 4])
 ...
@@ -198,7 +198,7 @@ def preprocess_for_train(image,
     # the coordinates are ordered [ymin, xmin, ymax, xmax].
     image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
                                                   bbox)
-    tf.summary.image('image_with_bounding_boxes', image_with_box)
+    tf.compat.v1.summary.image('image_with_bounding_boxes', image_with_box)
     distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
     # Restore the shape since the dynamic slice based upon the bbox_size loses
 ...
@@ -206,8 +206,8 @@ def preprocess_for_train(image,
     distorted_image.set_shape([None, None, 3])
     image_with_distorted_box = tf.image.draw_bounding_boxes(
         tf.expand_dims(image, 0), distorted_bbox)
-    tf.summary.image('images_with_distorted_bounding_box', image_with_distorted_box)
+    tf.compat.v1.summary.image('images_with_distorted_bounding_box',
+                               image_with_distorted_box)
     # This resizing operation may distort the images because the aspect
     # ratio is not respected. We select a resize method in a round robin
 ...
@@ -218,11 +218,11 @@ def preprocess_for_train(image,
     num_resize_cases = 1 if fast_mode else 4
     distorted_image = apply_with_random_selector(
         distorted_image,
-        lambda x, method: tf.image.resize_images(x, [height, width], method=method),
+        lambda x, method: tf.image.resize(x, [height, width], method=method),
         num_cases=num_resize_cases)
 
-    tf.summary.image('cropped_resized_image', tf.expand_dims(distorted_image, 0))
+    tf.compat.v1.summary.image('cropped_resized_image',
+                               tf.expand_dims(distorted_image, 0))
     # Randomly flip the image horizontally.
     distorted_image = tf.image.random_flip_left_right(distorted_image)
 ...
@@ -233,8 +233,8 @@ def preprocess_for_train(image,
         lambda x, ordering: distort_color(x, ordering, fast_mode),
         num_cases=4)
 
-    tf.summary.image('final_distorted_image', tf.expand_dims(distorted_image, 0))
+    tf.compat.v1.summary.image('final_distorted_image',
+                               tf.expand_dims(distorted_image, 0))
     distorted_image = tf.subtract(distorted_image, 0.5)
     distorted_image = tf.multiply(distorted_image, 2.0)
     return distorted_image
 ...
@@ -265,7 +265,7 @@ def preprocess_for_eval(image,
   Returns:
     3-D float Tensor of prepared image.
   """
-  with tf.name_scope(scope, 'eval_image', [image, height, width]):
+  with tf.compat.v1.name_scope(scope, 'eval_image', [image, height, width]):
     if image.dtype != tf.float32:
       image = tf.image.convert_image_dtype(image, dtype=tf.float32)
     # Crop the central region of the image with an area containing 87.5% of
 ...
@@ -276,8 +276,8 @@ def preprocess_for_eval(image,
     if height and width:
       # Resize the image to the specified height and width.
       image = tf.expand_dims(image, 0)
-      image = tf.image.resize_bilinear(image, [height, width], align_corners=False)
+      image = tf.image.resize(image, [height, width],
+                              method=tf.image.ResizeMethod.BILINEAR)
       image = tf.squeeze(image, [0])
     image = tf.subtract(image, 0.5)
     image = tf.multiply(image, 2.0)
 ...
```
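The `control_flow_ops.merge`/`switch` construction in `apply_with_random_selector` predates TF2; in eager or `tf.function` code the same "apply one of N transforms at random" pattern is more directly written with `tf.switch_case`. A sketch of that alternative (not part of this diff):

```python
import tensorflow as tf

def apply_with_random_selector_v2(x, func, num_cases):
  """Apply func(x, case) for a case sampled uniformly from [0, num_cases)."""
  sel = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
  # Each branch closes over a concrete Python int, mirroring the original
  # merge/switch construction one case at a time.
  return tf.switch_case(sel, [lambda c=case: func(x, c)
                              for case in range(num_cases)])

methods = [tf.image.ResizeMethod.BILINEAR,
           tf.image.ResizeMethod.NEAREST_NEIGHBOR,
           tf.image.ResizeMethod.BICUBIC,
           tf.image.ResizeMethod.AREA]
image = tf.zeros([16, 16, 3])
out = apply_with_random_selector_v2(
    image, lambda x, m: tf.image.resize(x, [8, 8], method=methods[m]), 4)
print(out.shape)  # (8, 8, 3)
```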
research/attention_ocr/python/metrics.py — view file @ afd5579f

```diff
@@ -34,20 +34,21 @@ def char_accuracy(predictions, targets, rej_char, streaming=False):
     a update_ops for execution and value tensor whose value on evaluation
     returns the total character accuracy.
   """
-  with tf.variable_scope('CharAccuracy'):
+  with tf.compat.v1.variable_scope('CharAccuracy'):
     predictions.get_shape().assert_is_compatible_with(targets.get_shape())
 
-    targets = tf.to_int32(targets)
+    targets = tf.cast(targets, dtype=tf.int32)
     const_rej_char = tf.constant(rej_char, shape=targets.get_shape())
-    weights = tf.to_float(tf.not_equal(targets, const_rej_char))
-    correct_chars = tf.to_float(tf.equal(predictions, targets))
-    accuracy_per_example = tf.div(
-        tf.reduce_sum(tf.multiply(correct_chars, weights), 1),
-        tf.reduce_sum(weights, 1))
+    weights = tf.cast(tf.not_equal(targets, const_rej_char), dtype=tf.float32)
+    correct_chars = tf.cast(tf.equal(predictions, targets), dtype=tf.float32)
+    accuracy_per_example = tf.compat.v1.div(
+        tf.reduce_sum(input_tensor=tf.multiply(correct_chars, weights), axis=1),
+        tf.reduce_sum(input_tensor=weights, axis=1))
     if streaming:
       return tf.contrib.metrics.streaming_mean(accuracy_per_example)
     else:
-      return tf.reduce_mean(accuracy_per_example)
+      return tf.reduce_mean(input_tensor=accuracy_per_example)
 
 def sequence_accuracy(predictions, targets, rej_char, streaming=False):
 ...
@@ -66,25 +67,26 @@ def sequence_accuracy(predictions, targets, rej_char, streaming=False):
     returns the total sequence accuracy.
   """
-  with tf.variable_scope('SequenceAccuracy'):
+  with tf.compat.v1.variable_scope('SequenceAccuracy'):
     predictions.get_shape().assert_is_compatible_with(targets.get_shape())
 
-    targets = tf.to_int32(targets)
+    targets = tf.cast(targets, dtype=tf.int32)
     const_rej_char = tf.constant(
         rej_char, shape=targets.get_shape(), dtype=tf.int32)
     include_mask = tf.not_equal(targets, const_rej_char)
-    include_predictions = tf.to_int32(
-        tf.where(include_mask, predictions,
-                 tf.zeros_like(predictions) + rej_char))
-    correct_chars = tf.to_float(tf.equal(include_predictions, targets))
+    include_predictions = tf.cast(
+        tf.compat.v1.where(include_mask, predictions,
+                           tf.zeros_like(predictions) + rej_char),
+        dtype=tf.int32)
+    correct_chars = tf.cast(
+        tf.equal(include_predictions, targets), dtype=tf.float32)
     correct_chars_counts = tf.cast(
-        tf.reduce_sum(correct_chars, reduction_indices=[1]), dtype=tf.int32)
+        tf.reduce_sum(input_tensor=correct_chars, axis=[1]), dtype=tf.int32)
     target_length = targets.get_shape().dims[1].value
     target_chars_counts = tf.constant(
         target_length, shape=correct_chars_counts.get_shape())
-    accuracy_per_example = tf.to_float(
-        tf.equal(correct_chars_counts, target_chars_counts))
+    accuracy_per_example = tf.cast(
+        tf.equal(correct_chars_counts, target_chars_counts), dtype=tf.float32)
     if streaming:
       return tf.contrib.metrics.streaming_mean(accuracy_per_example)
     else:
-      return tf.reduce_mean(accuracy_per_example)
+      return tf.reduce_mean(input_tensor=accuracy_per_example)
```
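The masking in `char_accuracy` is easiest to see on a concrete batch: positions whose target equals `rej_char` drop out of both the numerator and the denominator. A standalone TF2-style re-derivation of the same computation, with invented values:

```python
import tensorflow as tf

predictions = tf.constant([[3, 7, 0], [5, 5, 0]], dtype=tf.int32)
targets     = tf.constant([[3, 9, 0], [5, 5, 0]], dtype=tf.int32)
rej_char = 0  # padding/reject code, excluded from the accuracy

# [[1,1,0],[1,1,0]]: padded positions get zero weight.
weights = tf.cast(tf.not_equal(targets, rej_char), tf.float32)
# [[1,0,1],[1,1,1]]: raw per-character matches, including padding.
correct = tf.cast(tf.equal(predictions, targets), tf.float32)
per_example = (tf.reduce_sum(correct * weights, axis=1) /
               tf.reduce_sum(weights, axis=1))
print(per_example.numpy())                   # [0.5 1. ]
print(tf.reduce_mean(per_example).numpy())   # 0.75
```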
research/attention_ocr/python/metrics_test.py — view file @ afd5579f

```diff
@@ -38,8 +38,8 @@ class AccuracyTest(tf.test.TestCase):
       A session object that should be used as a context manager.
     """
     with self.cached_session() as sess:
-      sess.run(tf.global_variables_initializer())
-      sess.run(tf.local_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())
+      sess.run(tf.compat.v1.local_variables_initializer())
       yield sess
 
   def _fake_labels(self):
 ...
@@ -55,7 +55,7 @@ class AccuracyTest(tf.test.TestCase):
     return incorrect
 
   def test_sequence_accuracy_identical_samples(self):
-    labels_tf = tf.convert_to_tensor(self._fake_labels())
+    labels_tf = tf.convert_to_tensor(value=self._fake_labels())
     accuracy_tf = metrics.sequence_accuracy(labels_tf, labels_tf,
                                             self.rej_char)
 ...
@@ -66,9 +66,9 @@ class AccuracyTest(tf.test.TestCase):
   def test_sequence_accuracy_one_char_difference(self):
     ground_truth_np = self._fake_labels()
-    ground_truth_tf = tf.convert_to_tensor(ground_truth_np)
+    ground_truth_tf = tf.convert_to_tensor(value=ground_truth_np)
     prediction_tf = tf.convert_to_tensor(
-        self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
+        value=self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
     accuracy_tf = metrics.sequence_accuracy(prediction_tf, ground_truth_tf,
                                             self.rej_char)
 ...
@@ -80,9 +80,9 @@ class AccuracyTest(tf.test.TestCase):
   def test_char_accuracy_one_char_difference_with_padding(self):
     ground_truth_np = self._fake_labels()
-    ground_truth_tf = tf.convert_to_tensor(ground_truth_np)
+    ground_truth_tf = tf.convert_to_tensor(value=ground_truth_np)
     prediction_tf = tf.convert_to_tensor(
-        self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
+        value=self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
     accuracy_tf = metrics.char_accuracy(prediction_tf, ground_truth_tf,
                                         self.rej_char)
 ...
```
research/attention_ocr/python/model.py — view file @ afd5579f

```diff
@@ -92,8 +92,8 @@ class CharsetMapper(object):
     Args:
       ids: a tensor with shape [batch_size, max_sequence_length]
     """
-    return tf.reduce_join(
-        self.table.lookup(tf.to_int64(ids)), reduction_indices=1)
+    return tf.strings.reduce_join(
+        inputs=self.table.lookup(tf.cast(ids, dtype=tf.int64)), axis=1)
 
 def get_softmax_loss_fn(label_smoothing):
 ...
@@ -110,7 +110,7 @@ def get_softmax_loss_fn(label_smoothing):
     def loss_fn(labels, logits):
       return (tf.nn.softmax_cross_entropy_with_logits(
-          logits=logits, labels=labels))
+          logits=logits, labels=tf.stop_gradient(labels)))
   else:
     def loss_fn(labels, logits):
 ...
@@ -140,7 +140,7 @@ def get_tensor_dimensions(tensor):
     raise ValueError(
         'Incompatible shape: len(tensor.get_shape().dims) != 4 (%d != 4)' %
         len(tensor.get_shape().dims))
-  batch_size = tf.shape(tensor)[0]
+  batch_size = tf.shape(input=tensor)[0]
   height = tensor.get_shape().dims[1].value
   width = tensor.get_shape().dims[2].value
   num_features = tensor.get_shape().dims[3].value
 ...
@@ -161,8 +161,8 @@ def lookup_indexed_value(indices, row_vecs):
     A tensor of shape (batch, ) formed by row_vecs[i, indices[i]].
   """
   gather_indices = tf.stack((tf.range(
-      tf.shape(row_vecs)[0], dtype=tf.int32),
+      tf.shape(input=row_vecs)[0], dtype=tf.int32),
       tf.cast(indices, tf.int32)), axis=1)
   return tf.gather_nd(row_vecs, gather_indices)
 ...
@@ -181,7 +181,7 @@ def max_char_logprob_cumsum(char_log_prob):
   so the same function can be used regardless whether use_length_predictions
   is true or false.
   """
-  max_char_log_prob = tf.reduce_max(char_log_prob, reduction_indices=2)
+  max_char_log_prob = tf.reduce_max(input_tensor=char_log_prob, axis=2)
   # For an input array [a, b, c]) tf.cumsum returns [a, a + b, a + b + c] if
   # exclusive set to False (default).
   return tf.cumsum(max_char_log_prob, axis=1, exclusive=False)
 ...
@@ -203,7 +203,7 @@ def find_length_by_null(predicted_chars, null_code):
     A [batch, ] tensor which stores the sequence length for each sample.
   """
   return tf.reduce_sum(
-      tf.cast(tf.not_equal(null_code, predicted_chars), tf.int32), axis=1)
+      input_tensor=tf.cast(tf.not_equal(null_code, predicted_chars), tf.int32),
+      axis=1)
 
 def axis_pad(tensor, axis, before=0, after=0, constant_values=0.0):
 ...
@@ -248,7 +248,8 @@ def null_based_length_prediction(chars_log_prob, null_code):
       element #seq_length - is the probability of length=seq_length.
     predicted_length is a tensor with shape [batch].
   """
-  predicted_chars = tf.to_int32(tf.argmax(chars_log_prob, axis=2))
+  predicted_chars = tf.cast(
+      tf.argmax(input=chars_log_prob, axis=2), dtype=tf.int32)
   # We do right pad to support sequences with seq_length elements.
   text_log_prob = max_char_logprob_cumsum(
       axis_pad(chars_log_prob, axis=1, after=1))
 ...
@@ -334,9 +335,9 @@ class Model(object):
     """
     mparams = self._mparams['conv_tower_fn']
     logging.debug('Using final_endpoint=%s', mparams.final_endpoint)
-    with tf.variable_scope('conv_tower_fn/INCE'):
+    with tf.compat.v1.variable_scope('conv_tower_fn/INCE'):
       if reuse:
-        tf.get_variable_scope().reuse_variables()
+        tf.compat.v1.get_variable_scope().reuse_variables()
       with slim.arg_scope(inception.inception_v3_arg_scope()):
         with slim.arg_scope([slim.batch_norm, slim.dropout],
                             is_training=is_training):
 ...
@@ -372,7 +373,7 @@ class Model(object):
   def sequence_logit_fn(self, net, labels_one_hot):
     mparams = self._mparams['sequence_logit_fn']
     # TODO(gorban): remove /alias suffixes from the scopes.
-    with tf.variable_scope('sequence_logit_fn/SQLR'):
+    with tf.compat.v1.variable_scope('sequence_logit_fn/SQLR'):
       layer_class = sequence_layers.get_layer_class(mparams.use_attention,
                                                     mparams.use_autoregression)
       layer = layer_class(net, labels_one_hot, self._params, mparams)
 ...
@@ -392,7 +393,7 @@ class Model(object):
     ]
     xy_flat_shape = (batch_size, 1, height * width, num_features)
     nets_for_merge = []
-    with tf.variable_scope('max_pool_views', values=nets_list):
+    with tf.compat.v1.variable_scope('max_pool_views', values=nets_list):
       for net in nets_list:
         nets_for_merge.append(tf.reshape(net, xy_flat_shape))
       merged_net = tf.concat(nets_for_merge, 1)
 ...
@@ -413,10 +414,11 @@ class Model(object):
     Returns:
       A tensor of shape [batch_size, seq_length, features_size].
     """
-    with tf.variable_scope('pool_views_fn/STCK'):
+    with tf.compat.v1.variable_scope('pool_views_fn/STCK'):
       net = tf.concat(nets, 1)
-      batch_size = tf.shape(net)[0]
-      image_size = net.get_shape().dims[1].value * net.get_shape().dims[2].value
+      batch_size = tf.shape(input=net)[0]
+      image_size = net.get_shape().dims[1].value * \
+          net.get_shape().dims[2].value
       feature_size = net.get_shape().dims[3].value
       return tf.reshape(net, tf.stack([batch_size, image_size, feature_size]))
 ...
@@ -438,11 +440,13 @@ class Model(object):
       with shape [batch_size x seq_length].
     """
     log_prob = utils.logits_to_log_prob(chars_logit)
-    ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars')
+    ids = tf.cast(
+        tf.argmax(input=log_prob, axis=2), name='predicted_chars',
+        dtype=tf.int32)
     mask = tf.cast(
         slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
     all_scores = tf.nn.softmax(chars_logit)
-    selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores')
+    selected_scores = tf.boolean_mask(
+        tensor=all_scores, mask=mask, name='char_scores')
     scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length),
 ...
@@ -499,7 +503,7 @@ class Model(object):
     images = tf.subtract(images, 0.5)
     images = tf.multiply(images, 2.5)
-    with tf.variable_scope(scope, reuse=reuse):
+    with tf.compat.v1.variable_scope(scope, reuse=reuse):
       views = tf.split(
           value=images, num_or_size_splits=self._params.num_views, axis=2)
       logging.debug('Views=%d single view: %s', len(views), views[0])
 ...
@@ -566,7 +570,7 @@ class Model(object):
     # multiple losses including regularization losses.
     self.sequence_loss_fn(endpoints.chars_logit, data.labels)
     total_loss = slim.losses.get_total_loss()
-    tf.summary.scalar('TotalLoss', total_loss)
+    tf.compat.v1.summary.scalar('TotalLoss', total_loss)
     return total_loss
 
   def label_smoothing_regularization(self, chars_labels, weight=0.1):
 ...
@@ -605,7 +609,7 @@ class Model(object):
       A Tensor with shape [batch_size] - the log-perplexity for each sequence.
     """
     mparams = self._mparams['sequence_loss_fn']
-    with tf.variable_scope('sequence_loss_fn/SLF'):
+    with tf.compat.v1.variable_scope('sequence_loss_fn/SLF'):
       if mparams.label_smoothing > 0:
         smoothed_one_hot_labels = self.label_smoothing_regularization(
             chars_labels, mparams.label_smoothing)
 ...
@@ -625,7 +629,7 @@ class Model(object):
           shape=(batch_size, seq_length),
           dtype=tf.int64)
       known_char = tf.not_equal(chars_labels, reject_char)
-      weights = tf.to_float(known_char)
+      weights = tf.cast(known_char, dtype=tf.float32)
 
     logits_list = tf.unstack(chars_logits, axis=1)
     weights_list = tf.unstack(weights, axis=1)
 ...
@@ -635,7 +639,7 @@ class Model(object):
         weights_list,
         softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing),
         average_across_timesteps=mparams.average_across_timesteps)
-    tf.losses.add_loss(loss)
+    tf.compat.v1.losses.add_loss(loss)
     return loss
 
   def create_summaries(self, data, endpoints, charset, is_training):
 ...
@@ -665,13 +669,14 @@ class Model(object):
     # tf.summary.text(sname('text/pr'), pr_text)
     # gt_text = charset_mapper.get_text(data.labels[:max_outputs,:])
     # tf.summary.text(sname('text/gt'), gt_text)
-    tf.summary.image(sname('image'), data.images, max_outputs=max_outputs)
+    tf.compat.v1.summary.image(
+        sname('image'), data.images, max_outputs=max_outputs)
 
     if is_training:
-      tf.summary.image(
+      tf.compat.v1.summary.image(
           sname('image/orig'), data.images_orig, max_outputs=max_outputs)
-      for var in tf.trainable_variables():
-        tf.summary.histogram(var.op.name, var)
+      for var in tf.compat.v1.trainable_variables():
+        tf.compat.v1.summary.histogram(var.op.name, var)
       return None
     else:
 ...
@@ -700,7 +705,8 @@ class Model(object):
     for name, value in names_to_values.items():
       summary_name = 'eval/' + name
-      tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name))
+      tf.compat.v1.summary.scalar(
+          summary_name, tf.compat.v1.Print(value, [value], summary_name))
     return list(names_to_updates.values())
 
   def create_init_fn_to_restore(self,
 ...
@@ -733,9 +739,9 @@ class Model(object):
     logging.info('variables_to_restore:\n%s',
                  utils.variables_to_restore().keys())
     logging.info('moving_average_variables:\n%s',
-                 [v.op.name for v in tf.moving_average_variables()])
+                 [v.op.name for v in tf.compat.v1.moving_average_variables()])
     logging.info('trainable_variables:\n%s',
-                 [v.op.name for v in tf.trainable_variables()])
+                 [v.op.name for v in tf.compat.v1.trainable_variables()])
     if master_checkpoint:
       assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint)
 ...
```
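`find_length_by_null` relies on the padded tail of each prediction being filled with `null_code`, so counting non-null positions recovers the sequence length. A tiny numeric check (the null code and character values below are invented):

```python
import tensorflow as tf

def find_length_by_null(predicted_chars, null_code):
  """Count non-null characters per row, as in the hunk above."""
  return tf.reduce_sum(
      input_tensor=tf.cast(tf.not_equal(null_code, predicted_chars), tf.int32),
      axis=1)

null_code = 133
chars = tf.constant([[5, 9, 12, 133, 133],
                     [7, 133, 133, 133, 133]], dtype=tf.int32)
print(find_length_by_null(chars, null_code).numpy())  # [3 1]
```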
research/attention_ocr/python/model_export.py — view file @ afd5579f

```diff
@@ -42,7 +42,8 @@ flags.DEFINE_integer(
     'image_height', None,
     'Image height used during training(or crop height if used)'
     ' If not set, the dataset default is used instead.')
-flags.DEFINE_string('work_dir', '/tmp', 'A directory to store temporary files.')
+flags.DEFINE_string('work_dir', '/tmp',
+                    'A directory to store temporary files.')
 flags.DEFINE_integer('version_number', 1, 'Version number of the model')
 flags.DEFINE_bool('export_for_serving', True,
 ...
@@ -116,7 +117,7 @@ def export_model(export_dir,
   image_height = crop_image_height or dataset_image_height
 
   if export_for_serving:
-    images_orig = tf.placeholder(
+    images_orig = tf.compat.v1.placeholder(
         tf.string, shape=[batch_size], name='tf_example')
     images_orig_float = model_export_lib.generate_tfexample_image(
         images_orig,
 ...
@@ -126,22 +127,23 @@ def export_model(export_dir,
         name='float_images')
   else:
     images_shape = (batch_size, image_height, image_width, image_depth)
-    images_orig = tf.placeholder(
+    images_orig = tf.compat.v1.placeholder(
         tf.uint8, shape=images_shape, name='original_image')
     images_orig_float = tf.image.convert_image_dtype(
         images_orig, dtype=tf.float32, name='float_images')
 
   endpoints = model.create_base(images_orig_float, labels_one_hot=None)
 
-  sess = tf.Session()
-  saver = tf.train.Saver(slim.get_variables_to_restore(), sharded=True)
+  sess = tf.compat.v1.Session()
+  saver = tf.compat.v1.train.Saver(
+      slim.get_variables_to_restore(), sharded=True)
   saver.restore(sess, get_checkpoint_path())
-  tf.logging.info('Model restored successfully.')
+  tf.compat.v1.logging.info('Model restored successfully.')
 
   # Create model signature.
   if export_for_serving:
     input_tensors = {
-        tf.saved_model.signature_constants.CLASSIFY_INPUTS: images_orig
+        tf.saved_model.CLASSIFY_INPUTS: images_orig
     }
   else:
     input_tensors = {'images': images_orig}
 ...
@@ -163,21 +165,21 @@ def export_model(export_dir,
                      dataset.max_sequence_length)):
     output_tensors['attention_mask_%d' % i] = t
   signature_outputs = model_export_lib.build_tensor_info(output_tensors)
-  signature_def = tf.saved_model.signature_def_utils.build_signature_def(
+  signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
       signature_inputs, signature_outputs,
-      tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)
+      tf.saved_model.CLASSIFY_METHOD_NAME)
   # Save model.
-  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
+  builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
   builder.add_meta_graph_and_variables(
-      sess, [tf.saved_model.tag_constants.SERVING],
+      sess, [tf.saved_model.SERVING],
       signature_def_map={
-          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
               signature_def
       },
-      main_op=tf.tables_initializer(),
+      main_op=tf.compat.v1.tables_initializer(),
      strip_default_attrs=True)
   builder.save()
-  tf.logging.info('Model has been exported to %s' % export_dir)
+  tf.compat.v1.logging.info('Model has been exported to %s' % export_dir)
 
   return signature_def
 ...
```
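Once exported, the SavedModel can be loaded without any of the code above; under TF2, a v1-style SavedModel is read back through `tf.saved_model.load` with the serving tag. A sketch, with a placeholder export directory:

```python
import tensorflow as tf

# '/tmp/attention_ocr_export' is a placeholder for the export_dir used above.
loaded = tf.saved_model.load('/tmp/attention_ocr_export',
                             tags=[tf.saved_model.SERVING])
infer = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
print(infer.structured_input_signature)  # the exported input tensors
```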