ModelZoo / ResNet50_tensorflow · Commits

Commit 2d5b39ad
Authored Jul 12, 2017 by Neal Wu; committed by GitHub on Jul 12, 2017

    Merge pull request #1865 from alexgorban/master

    Spatial attention for the Attention OCR model.

Parents: f679a001, f282f6ef
Changes: 3 changed files, with 121 additions and 4 deletions

    attention_ocr/README.md             +4   -1
    attention_ocr/python/model.py       +33  -1
    attention_ocr/python/model_test.py  +84  -2
attention_ocr/README.md

@@ -142,6 +142,9 @@ python train.py --dataset_name=newtextdataset
 
    Please note that eval.py will also require the same flag.
 
+   To learn how to store data in the FSNS format, please refer to
+   https://stackoverflow.com/a/44461910/743658.
+
 2. Define a new dataset format. The model needs the following data to train:
 
    - images: input images, shape [batch_size x H x W x 3];
@@ -176,4 +179,4 @@ The main difference between this version and the version used in the paper - for
 the paper we used a distributed training with 50 GPU (K80) workers (asynchronous
 updates), the provided checkpoint was created using this code after ~6 days of
 training on a single GPU (Titan X) (it reached 81% after 24 hours of training),
-the coordinate encoding is missing TODO(alexgorban@).
+the coordinate encoding is disabled by default.
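
A note on the README wording above: this commit does not expose the coordinate encoding as a train.py flag; it is a model hyperparameter that defaults to disabled (EncodeCoordinatesParams(enabled=False) in model.py below). The following is a minimal sketch of how a caller could enable it, borrowing the set_mparam pattern from the new tests; the import path and the Model constructor arguments are illustrative assumptions, not part of this diff.

    # Sketch only: enabling the coordinate encoding on a Model instance.
    # The constructor arguments below are assumed placeholder values.
    import model  # attention_ocr/python/model.py

    ocr_model = model.Model(num_char_classes=134, seq_length=37,
                            num_views=4, null_code=133)
    # Coordinate encoding ships disabled by default; switch it on per instance,
    # exactly as model_test.py does in this commit.
    ocr_model.set_mparam('encode_coordinates_fn', enabled=True)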
attention_ocr/python/model.py

@@ -55,6 +55,10 @@ SequenceLossParams = collections.namedtuple('SequenceLossParams', [
     'label_smoothing', 'ignore_nulls', 'average_across_timesteps'
 ])
 
+EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [
+    'enabled'
+])
+
 
 def _dict_to_array(id_to_char, default_character):
   num_char_classes = max(id_to_char.keys()) + 1
@@ -162,7 +166,8 @@ class Model(object):
         SequenceLossParams(
             label_smoothing=0.1,
             ignore_nulls=True,
-            average_across_timesteps=False)
+            average_across_timesteps=False),
+        'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False)
     }
 
   def set_mparam(self, function, **kwargs):
@@ -293,6 +298,30 @@ class Model(object):
     scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length))
     return ids, log_prob, scores
 
+  def encode_coordinates_fn(self, net):
+    """Adds one-hot encoding of coordinates to different views in the networks.
+
+    For each "pixel" of a feature map it adds one-hot encoded x and y
+    coordinates.
+
+    Args:
+      net: a tensor of shape=[batch_size, height, width, num_features]
+
+    Returns:
+      a tensor with the same height and width, but altered feature_size.
+    """
+    mparams = self._mparams['encode_coordinates_fn']
+    if mparams.enabled:
+      batch_size, h, w, _ = net.shape.as_list()
+      x, y = tf.meshgrid(tf.range(w), tf.range(h))
+      w_loc = slim.one_hot_encoding(x, num_classes=w)
+      h_loc = slim.one_hot_encoding(y, num_classes=h)
+      loc = tf.concat([h_loc, w_loc], 2)
+      loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
+      return tf.concat([net, loc], 3)
+    else:
+      return net
+
   def create_base(self,
                   images,
                   labels_one_hot,
@@ -324,6 +353,9 @@ class Model(object):
     ]
 
     logging.debug('Conv tower: %s', nets[0])
 
+    nets = [self.encode_coordinates_fn(net) for net in nets]
+    logging.debug('Conv tower w/ encoded coordinates: %s', nets[0])
+
     net = self.pool_views_fn(nets)
     logging.debug('Pooled views: %s', net)
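
As a reading aid (not part of the commit), here is a NumPy-only restatement of what encode_coordinates_fn does when enabled: every spatial position of an H x W feature map gets a one-hot row (height) index followed by a one-hot column (width) index appended to its channels, so the depth grows from C to C + H + W, matching the simple example in the new tests below.

    # Illustrative NumPy sketch of the coordinate encoding added above.
    import numpy as np

    def encode_coordinates_np(net):
        """net: float array of shape [batch_size, H, W, C]."""
        batch_size, h, w, _ = net.shape
        y_onehot = np.eye(h, dtype=net.dtype)              # [H, H] one-hot rows
        x_onehot = np.eye(w, dtype=net.dtype)              # [W, W] one-hot rows
        y_grid = np.tile(y_onehot[:, None, :], (1, w, 1))  # [H, W, H]
        x_grid = np.tile(x_onehot[None, :, :], (h, 1, 1))  # [H, W, W]
        loc = np.concatenate([y_grid, x_grid], axis=2)     # [H, W, H + W]
        loc = np.tile(loc[None], (batch_size, 1, 1, 1))    # [B, H, W, H + W]
        return np.concatenate([net, loc], axis=3)          # [B, H, W, C + H + W]

    features = np.ones((1, 2, 3, 4), dtype=np.float32)
    print(encode_coordinates_np(features).shape)          # (1, 2, 3, 9)
    print(encode_coordinates_np(features)[0, 0, 0, 4:])   # [1. 0. 1. 0. 0.]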
attention_ocr/python/model_test.py

@@ -62,8 +62,9 @@ class ModelTest(tf.test.TestCase):
             self.rng.randint(low=0, high=255,
                              size=self.images_shape).astype('float32'),
         name='input_node')
-    self.fake_conv_tower = tf.constant(
-        self.rng.randn(*self.conv_tower_shape).astype('float32'))
+    self.fake_conv_tower_np = self.rng.randn(
+        *self.conv_tower_shape).astype('float32')
+    self.fake_conv_tower = tf.constant(self.fake_conv_tower_np)
     self.fake_logits = tf.constant(
         self.rng.randn(*self.chars_logit_shape).astype('float32'))
     self.fake_labels = tf.constant(
@@ -162,6 +163,87 @@ class ModelTest(tf.test.TestCase):
     # This test checks that the loss function is 'runnable'.
     self.assertEqual(loss_np.shape, tuple())
 
+  def encode_coordinates_alt(self, net):
+    """An alternative implementation of the coordinate encoding.
+
+    Args:
+      net: a tensor of shape=[batch_size, height, width, num_features]
+
+    Returns:
+      a list of tensors with encoded image coordinates in them.
+    """
+    batch_size, h, w, _ = net.shape.as_list()
+    h_loc = [
+        tf.tile(
+            tf.reshape(
+                tf.contrib.layers.one_hot_encoding(
+                    tf.constant([i]), num_classes=h), [h, 1]), [1, w])
+        for i in xrange(h)
+    ]
+    h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
+    w_loc = [
+        tf.tile(
+            tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
+            [h, 1]) for i in xrange(w)
+    ]
+    w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
+    loc = tf.concat([h_loc, w_loc], 2)
+    loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
+    return tf.concat([net, loc], 3)
+
+  def test_encoded_coordinates_have_correct_shape(self):
+    model = self.create_model()
+    model.set_mparam('encode_coordinates_fn', enabled=True)
+    conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
+
+    with self.test_session() as sess:
+      conv_w_coords = sess.run(conv_w_coords_tf)
+
+    batch_size, height, width, feature_size = self.conv_tower_shape
+    self.assertEqual(conv_w_coords.shape,
+                     (batch_size, height, width,
+                      feature_size + height + width))
+
+  def test_disabled_coordinate_encoding_returns_features_unchanged(self):
+    model = self.create_model()
+    model.set_mparam('encode_coordinates_fn', enabled=False)
+    conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
+
+    with self.test_session() as sess:
+      conv_w_coords = sess.run(conv_w_coords_tf)
+
+    self.assertAllEqual(conv_w_coords, self.fake_conv_tower_np)
+
+  def test_coordinate_encoding_is_correct_for_simple_example(self):
+    shape = (1, 2, 3, 4)  # batch_size, height, width, feature_size
+    fake_conv_tower = tf.constant(2 * np.ones(shape), dtype=tf.float32)
+    model = self.create_model()
+    model.set_mparam('encode_coordinates_fn', enabled=True)
+    conv_w_coords_tf = model.encode_coordinates_fn(fake_conv_tower)
+
+    with self.test_session() as sess:
+      conv_w_coords = sess.run(conv_w_coords_tf)
+
+    # Original features
+    self.assertAllEqual(conv_w_coords[0, :, :, :4],
+                        [[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]],
+                         [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]])
+    # Encoded coordinates
+    self.assertAllEqual(conv_w_coords[0, :, :, 4:],
+                        [[[1, 0, 1, 0, 0], [1, 0, 0, 1, 0], [1, 0, 0, 0, 1]],
+                         [[0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 1, 0, 0, 1]]])
+
+  def test_alt_implementation_of_coordinate_encoding_returns_same_values(self):
+    model = self.create_model()
+    model.set_mparam('encode_coordinates_fn', enabled=True)
+    conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
+    conv_w_coords_alt_tf = self.encode_coordinates_alt(self.fake_conv_tower)
+
+    with self.test_session() as sess:
+      conv_w_coords_tf, conv_w_coords_alt_tf = sess.run(
+          [conv_w_coords_tf, conv_w_coords_alt_tf])
+
+    self.assertAllEqual(conv_w_coords_tf, conv_w_coords_alt_tf)
+
 
 class CharsetMapperTest(tf.test.TestCase):
 
   def test_text_corresponds_to_ids(self):
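
Usage note (not stated in the diff): assuming model_test.py keeps the usual tf.test.main() entry point, the new coordinate-encoding tests can be run directly with "python model_test.py" from attention_ocr/python; the added code uses tf.contrib.layers and xrange, so it expects a TF 1.x / Python 2 environment.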