Commit 2d5b39ad authored by Neal Wu, committed by GitHub

Merge pull request #1865 from alexgorban/master

Spatial attention for the Attention OCR model.
parents f679a001 f282f6ef
@@ -142,6 +142,9 @@ python train.py --dataset_name=newtextdataset
Please note that eval.py will also require the same flag.
To learn how to store data in the FSNS format, please refer to
https://stackoverflow.com/a/44461910/743658 (a minimal TFRecord-writer sketch follows this excerpt).
2. Define a new dataset format. The model needs the following data to train:
   - images: input images, shape [batch_size x H x W x 3];
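Since the README change above points readers to a StackOverflow answer for storing data in the FSNS format, here is a minimal TFRecord-writer sketch in the repository's TF 1.x style. This is an editor's illustration, not part of the diff: the feature keys and the helper name make_fsns_example are assumptions based on the FSNS convention described in the linked answer and should be checked against the dataset definition in the repository before relying on them.

# Editor's illustration; the feature keys below are assumed, verify them
# against the FSNS dataset definition before use.
import tensorflow as tf

def make_fsns_example(png_bytes, text, char_ids_padded, char_ids_unpadded, width):
  """Builds one tf.train.Example in the (assumed) FSNS feature layout."""
  def _bytes(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  def _int64(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': _bytes(png_bytes),                 # PNG-encoded image bytes
      'image/format': _bytes(b'PNG'),
      'image/width': _int64([width]),
      'image/orig_width': _int64([width]),
      'image/class': _int64(char_ids_padded),             # char ids padded to seq_length
      'image/unpadded_class': _int64(char_ids_unpadded),  # char ids without padding
      'image/text': _bytes(text.encode('utf-8')),
  }))

# Usage (sketch): write serialized examples into a TFRecord shard.
# writer = tf.python_io.TFRecordWriter('fsns-train-00000-of-00001')
# writer.write(make_fsns_example(...).SerializeToString())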
@@ -176,4 +179,4 @@ The main difference between this version and the version used in the paper - for
the paper we used distributed training with 50 GPU (K80) workers (asynchronous
updates); the provided checkpoint was created using this code after ~6 days of
training on a single GPU (Titan X) (it reached 81% after 24 hours of training);
the coordinate encoding is disabled by default.
@@ -55,6 +55,10 @@ SequenceLossParams = collections.namedtuple('SequenceLossParams', [
    'label_smoothing', 'ignore_nulls', 'average_across_timesteps'
])

EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [
    'enabled'
])
def _dict_to_array(id_to_char, default_character):
  num_char_classes = max(id_to_char.keys()) + 1
@@ -162,7 +166,8 @@ class Model(object):
        SequenceLossParams(
            label_smoothing=0.1,
            ignore_nulls=True,
            average_across_timesteps=False),
        'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False)
    }

  def set_mparam(self, function, **kwargs):
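Since EncodeCoordinatesParams defaults to enabled=False above, callers opt in per model instance through set_mparam, exactly as the new tests in this PR do. A one-line sketch (editor's illustration; `model` is assumed to be an already constructed Model instance):

# Editor's illustration; `model` is an assumed, already constructed Model instance.
model.set_mparam('encode_coordinates_fn', enabled=True)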
@@ -293,6 +298,30 @@ class Model(object):
    scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length))
    return ids, log_prob, scores
  def encode_coordinates_fn(self, net):
    """Adds one-hot encoding of coordinates to different views in the networks.

    For each "pixel" of a feature map it adds one-hot encoded x and y
    coordinates.

    Args:
      net: a tensor of shape=[batch_size, height, width, num_features]

    Returns:
      a tensor with the same height and width, but altered feature_size.
    """
    mparams = self._mparams['encode_coordinates_fn']
    if mparams.enabled:
      batch_size, h, w, _ = net.shape.as_list()
      x, y = tf.meshgrid(tf.range(w), tf.range(h))
      w_loc = slim.one_hot_encoding(x, num_classes=w)
      h_loc = slim.one_hot_encoding(y, num_classes=h)
      loc = tf.concat([h_loc, w_loc], 2)
      loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
      return tf.concat([net, loc], 3)
    else:
      return net
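To make the layout concrete, here is a small NumPy-only sketch of what the enabled path above appends to a [1, 2, 3, 4] feature map. This is an editor's illustration, not part of the diff; it mirrors the TF ops and agrees with the simple-example test further down (feature_size grows from 4 to 4 + height + width = 9).

# Editor's illustration (assumption: plain NumPy stand-in for the TF ops above).
import numpy as np

batch_size, h, w, f = 1, 2, 3, 4
net = 2 * np.ones((batch_size, h, w, f), dtype=np.float32)

h_loc = np.eye(h, dtype=np.float32)                   # one-hot row (y) indices, [h, h]
w_loc = np.eye(w, dtype=np.float32)                   # one-hot column (x) indices, [w, w]
loc = np.concatenate(
    [np.repeat(h_loc[:, None, :], w, axis=1),         # [h, w, h]
     np.repeat(w_loc[None, :, :], h, axis=0)],        # [h, w, w]
    axis=2)                                           # [h, w, h + w]
loc = np.tile(loc[None, ...], (batch_size, 1, 1, 1))  # [batch_size, h, w, h + w]
out = np.concatenate([net, loc], axis=3)

print(out.shape)         # (1, 2, 3, 9)
print(out[0, 0, 1, 4:])  # [1. 0. 0. 1. 0.] -> one-hot row 0, one-hot column 1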
  def create_base(self,
                  images,
                  labels_one_hot,
@@ -324,6 +353,9 @@ class Model(object):
    ]
    logging.debug('Conv tower: %s', nets[0])

    nets = [self.encode_coordinates_fn(net) for net in nets]
    logging.debug('Conv tower w/ encoded coordinates: %s', nets[0])

    net = self.pool_views_fn(nets)
    logging.debug('Pooled views: %s', net)
...
@@ -62,8 +62,9 @@ class ModelTest(tf.test.TestCase):
        self.rng.randint(low=0, high=255,
                         size=self.images_shape).astype('float32'),
        name='input_node')
    self.fake_conv_tower_np = self.rng.randn(
        *self.conv_tower_shape).astype('float32')
    self.fake_conv_tower = tf.constant(self.fake_conv_tower_np)
    self.fake_logits = tf.constant(
        self.rng.randn(*self.chars_logit_shape).astype('float32'))
    self.fake_labels = tf.constant(
@@ -162,6 +163,87 @@ class ModelTest(tf.test.TestCase):
    # This test checks that the loss function is 'runnable'.
    self.assertEqual(loss_np.shape, tuple())
  def encode_coordinates_alt(self, net):
    """An alternative implementation of the coordinate encoding.

    Args:
      net: a tensor of shape=[batch_size, height, width, num_features]

    Returns:
      a tensor with the same height and width, but altered feature_size.
    """
    batch_size, h, w, _ = net.shape.as_list()
    h_loc = [
        tf.tile(
            tf.reshape(
                tf.contrib.layers.one_hot_encoding(
                    tf.constant([i]), num_classes=h), [h, 1]), [1, w])
        for i in xrange(h)
    ]
    h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
    w_loc = [
        tf.tile(
            tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
            [h, 1]) for i in xrange(w)
    ]
    w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
    loc = tf.concat([h_loc, w_loc], 2)
    loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
    return tf.concat([net, loc], 3)
  def test_encoded_coordinates_have_correct_shape(self):
    model = self.create_model()
    model.set_mparam('encode_coordinates_fn', enabled=True)
    conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)

    with self.test_session() as sess:
      conv_w_coords = sess.run(conv_w_coords_tf)

    batch_size, height, width, feature_size = self.conv_tower_shape
    self.assertEqual(conv_w_coords.shape, (batch_size, height, width,
                                           feature_size + height + width))

  def test_disabled_coordinate_encoding_returns_features_unchanged(self):
    model = self.create_model()
    model.set_mparam('encode_coordinates_fn', enabled=False)
    conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)

    with self.test_session() as sess:
      conv_w_coords = sess.run(conv_w_coords_tf)

    self.assertAllEqual(conv_w_coords, self.fake_conv_tower_np)

  def test_coordinate_encoding_is_correct_for_simple_example(self):
    shape = (1, 2, 3, 4)  # batch_size, height, width, feature_size
    fake_conv_tower = tf.constant(2 * np.ones(shape), dtype=tf.float32)
    model = self.create_model()
    model.set_mparam('encode_coordinates_fn', enabled=True)
    conv_w_coords_tf = model.encode_coordinates_fn(fake_conv_tower)

    with self.test_session() as sess:
      conv_w_coords = sess.run(conv_w_coords_tf)

    # Original features
    self.assertAllEqual(conv_w_coords[0, :, :, :4],
                        [[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]],
                         [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]])
    # Encoded coordinates
    self.assertAllEqual(conv_w_coords[0, :, :, 4:],
                        [[[1, 0, 1, 0, 0], [1, 0, 0, 1, 0], [1, 0, 0, 0, 1]],
                         [[0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 1, 0, 0, 1]]])

  def test_alt_implementation_of_coordinate_encoding_returns_same_values(self):
    model = self.create_model()
    model.set_mparam('encode_coordinates_fn', enabled=True)
    conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
    conv_w_coords_alt_tf = self.encode_coordinates_alt(self.fake_conv_tower)

    with self.test_session() as sess:
      conv_w_coords_tf, conv_w_coords_alt_tf = sess.run(
          [conv_w_coords_tf, conv_w_coords_alt_tf])

    self.assertAllEqual(conv_w_coords_tf, conv_w_coords_alt_tf)
class CharsetMapperTest(tf.test.TestCase):

  def test_text_corresponds_to_ids(self):
...