"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fbff43acc9f52aec18e27806cc258a592f8b53f6"
Commit 58a5da7b authored by Alexander Gorban's avatar Alexander Gorban
Browse files

Add spatial attention.

The spatial attention described in the paper was ported from the internal
implementation and is disabled by default.
parent 1a392371
...@@ -142,6 +142,9 @@ python train.py --dataset_name=newtextdataset ...@@ -142,6 +142,9 @@ python train.py --dataset_name=newtextdataset
Please note that eval.py will also require the same flag.
To learn how to store data in the FSNS
format please refer to https://stackoverflow.com/a/44461910/743658.
2. Define a new dataset format. The model needs the following data to train:
- images: input images, shape [batch_size x H x W x 3];
...@@ -176,4 +179,4 @@ The main difference between this version and the version used in the paper - for ...@@ -176,4 +179,4 @@ The main difference between this version and the version used in the paper - for
the paper we used a distributed training with 50 GPU (K80) workers (asynchronous
updates); the provided checkpoint was created using this code after ~6 days of
training on a single GPU (Titan X) (it reached 81% after 24 hours of training),
the coordinate encoding is disabled by default.
...@@ -55,6 +55,10 @@ SequenceLossParams = collections.namedtuple('SequenceLossParams', [ ...@@ -55,6 +55,10 @@ SequenceLossParams = collections.namedtuple('SequenceLossParams', [
'label_smoothing', 'ignore_nulls', 'average_across_timesteps' 'label_smoothing', 'ignore_nulls', 'average_across_timesteps'
]) ])
# Hyper-parameters for the one-hot coordinate encoding of conv-tower features.
EncodeCoordinatesParams = collections.namedtuple(
    'EncodeCoordinatesParams', ['enabled'])
def _dict_to_array(id_to_char, default_character): def _dict_to_array(id_to_char, default_character):
num_char_classes = max(id_to_char.keys()) + 1 num_char_classes = max(id_to_char.keys()) + 1
...@@ -162,7 +166,8 @@ class Model(object): ...@@ -162,7 +166,8 @@ class Model(object):
SequenceLossParams( SequenceLossParams(
label_smoothing=0.1, label_smoothing=0.1,
ignore_nulls=True, ignore_nulls=True,
average_across_timesteps=False) average_across_timesteps=False),
'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False)
} }
def set_mparam(self, function, **kwargs): def set_mparam(self, function, **kwargs):
...@@ -293,6 +298,30 @@ class Model(object): ...@@ -293,6 +298,30 @@ class Model(object):
scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length)) scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length))
return ids, log_prob, scores return ids, log_prob, scores
def encode_coordinates_fn(self, net):
  """Adds one-hot encoding of coordinates to different views in the networks.

  Each spatial position of the feature map gets its x and y coordinates,
  one-hot encoded, concatenated onto the feature axis. Controlled by the
  'encode_coordinates_fn' hyper-parameter; when disabled the input tensor
  is returned unchanged.

  Args:
    net: a tensor of shape=[batch_size, height, width, num_features]

  Returns:
    a tensor with the same height and width, but altered feature_size.
  """
  mparams = self._mparams['encode_coordinates_fn']
  if not mparams.enabled:
    return net

  batch_size, height, width, _ = net.shape.as_list()
  # x_coords[i, j] == j and y_coords[i, j] == i for every pixel (i, j).
  x_coords, y_coords = tf.meshgrid(tf.range(width), tf.range(height))
  one_hot_x = slim.one_hot_encoding(x_coords, num_classes=width)
  one_hot_y = slim.one_hot_encoding(y_coords, num_classes=height)
  coords = tf.concat([one_hot_y, one_hot_x], 2)
  # Replicate the (height, width, height + width) map for every batch entry.
  coords = tf.tile(tf.expand_dims(coords, 0), [batch_size, 1, 1, 1])
  return tf.concat([net, coords], 3)
def create_base(self, def create_base(self,
images, images,
labels_one_hot, labels_one_hot,
...@@ -324,6 +353,9 @@ class Model(object): ...@@ -324,6 +353,9 @@ class Model(object):
] ]
logging.debug('Conv tower: %s', nets[0]) logging.debug('Conv tower: %s', nets[0])
nets = [self.encode_coordinates_fn(net) for net in nets]
logging.debug('Conv tower w/ encoded coordinates: %s', nets[0])
net = self.pool_views_fn(nets) net = self.pool_views_fn(nets)
logging.debug('Pooled views: %s', net) logging.debug('Pooled views: %s', net)
......
...@@ -62,8 +62,9 @@ class ModelTest(tf.test.TestCase): ...@@ -62,8 +62,9 @@ class ModelTest(tf.test.TestCase):
self.rng.randint(low=0, high=255, self.rng.randint(low=0, high=255,
size=self.images_shape).astype('float32'), size=self.images_shape).astype('float32'),
name='input_node') name='input_node')
self.fake_conv_tower_np = tf.constant( self.fake_conv_tower_np = self.rng.randn(
self.rng.randn(*self.conv_tower_shape).astype('float32')) *self.conv_tower_shape).astype('float32')
self.fake_conv_tower = tf.constant(self.fake_conv_tower_np)
self.fake_logits = tf.constant( self.fake_logits = tf.constant(
self.rng.randn(*self.chars_logit_shape).astype('float32')) self.rng.randn(*self.chars_logit_shape).astype('float32'))
self.fake_labels = tf.constant( self.fake_labels = tf.constant(
...@@ -162,6 +163,87 @@ class ModelTest(tf.test.TestCase): ...@@ -162,6 +163,87 @@ class ModelTest(tf.test.TestCase):
# This test checks that the loss function is 'runnable'. # This test checks that the loss function is 'runnable'.
self.assertEqual(loss_np.shape, tuple()) self.assertEqual(loss_np.shape, tuple())
def encode_coordinates_alt(self, net):
  """An alternative implementation for the encoding coordinates.

  Builds the same one-hot x/y coordinate encoding as
  Model.encode_coordinates_fn, but constructs each row/column indicator
  tensor explicitly, which makes it a useful cross-check in tests.

  Args:
    net: a tensor of shape=[batch_size, height, width, num_features]

  Returns:
    a tensor with the same height and width, but altered feature_size
    (num_features + height + width).
  """
  batch_size, h, w, _ = net.shape.as_list()
  # One-hot row index, replicated across every column of that row.
  # 'range' instead of the Python-2-only 'xrange' builtin.
  h_loc = [
      tf.tile(
          tf.reshape(
              tf.contrib.layers.one_hot_encoding(
                  tf.constant([i]), num_classes=h), [h, 1]), [1, w])
      for i in range(h)
  ]
  h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
  # One-hot column index, replicated down every row of that column.
  w_loc = [
      tf.tile(
          tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
          [h, 1]) for i in range(w)
  ]
  w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
  loc = tf.concat([h_loc, w_loc], 2)
  loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
  return tf.concat([net, loc], 3)
def test_encoded_coordinates_have_correct_shape(self):
  """Enabling coordinate encoding must grow the feature axis by h + w."""
  model = self.create_model()
  model.set_mparam('encode_coordinates_fn', enabled=True)
  encoded_tf = model.encode_coordinates_fn(self.fake_conv_tower)
  with self.test_session() as sess:
    encoded = sess.run(encoded_tf)

  batch_size, height, width, feature_size = self.conv_tower_shape
  expected = (batch_size, height, width, feature_size + height + width)
  self.assertEqual(encoded.shape, expected)
def test_disabled_coordinate_encoding_returns_features_unchanged(self):
  """With the feature disabled, the conv tower must pass through untouched."""
  model = self.create_model()
  model.set_mparam('encode_coordinates_fn', enabled=False)
  passthrough_tf = model.encode_coordinates_fn(self.fake_conv_tower)
  with self.test_session() as sess:
    passthrough = sess.run(passthrough_tf)

  self.assertAllEqual(passthrough, self.fake_conv_tower_np)
def test_coordinate_encoding_is_correct_for_simple_example(self):
  """Checks exact encoded values on a tiny constant feature map."""
  shape = (1, 2, 3, 4)  # batch_size, height, width, feature_size
  constant_features = tf.constant(np.full(shape, 2.0), dtype=tf.float32)
  model = self.create_model()
  model.set_mparam('encode_coordinates_fn', enabled=True)
  encoded_tf = model.encode_coordinates_fn(constant_features)
  with self.test_session() as sess:
    encoded = sess.run(encoded_tf)

  # The first 4 channels are the untouched input features.
  self.assertAllEqual(encoded[0, :, :, :4],
                      [[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]],
                       [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]])
  # The remaining channels are one-hot y (height=2) then one-hot x (width=3).
  self.assertAllEqual(encoded[0, :, :, 4:],
                      [[[1, 0, 1, 0, 0], [1, 0, 0, 1, 0], [1, 0, 0, 0, 1]],
                       [[0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 1, 0, 0, 1]]])
def test_alt_implementation_of_coordinate_encoding_returns_same_values(self):
  """The model implementation must agree with the explicit reference one."""
  model = self.create_model()
  model.set_mparam('encode_coordinates_fn', enabled=True)
  primary_tf = model.encode_coordinates_fn(self.fake_conv_tower)
  reference_tf = self.encode_coordinates_alt(self.fake_conv_tower)
  with self.test_session() as sess:
    primary, reference = sess.run([primary_tf, reference_tf])

  self.assertAllEqual(primary, reference)
class CharsetMapperTest(tf.test.TestCase): class CharsetMapperTest(tf.test.TestCase):
def test_text_corresponds_to_ids(self): def test_text_corresponds_to_ids(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment