model.py 27.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to build the Attention OCR model.

Usage example:
  ocr_model = model.Model(num_char_classes, seq_length, num_of_views)

  data = ... # create namedtuple InputEndpoints
  endpoints = ocr_model.create_base(data.images, data.labels_one_hot)
  # endpoints.predicted_chars is a tensor with predicted character codes.
  total_loss = ocr_model.create_loss(data, endpoints)
"""
import sys
import collections
import logging
28
import numpy as np
29
30
31
32
33
34
35
36
37
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim.nets import inception

import metrics
import sequence_layers
import utils

# Tensors returned by Model.create_base.
OutputEndpoints = collections.namedtuple(
    'OutputEndpoints',
    'chars_logit chars_log_prob predicted_chars predicted_scores '
    'predicted_text predicted_length predicted_conf normalized_seq_conf')

# TODO(gorban): replace with tf.HParams when it is released.
# Static, model-wide parameters (see Model.__init__).
ModelParams = collections.namedtuple(
    'ModelParams', 'num_char_classes seq_length num_views null_code')

# Hyper parameters for Model.conv_tower_fn.
ConvTowerParams = collections.namedtuple('ConvTowerParams', 'final_endpoint')

# Hyper parameters for Model.sequence_logit_fn.
SequenceLogitsParams = collections.namedtuple(
    'SequenceLogitsParams',
    'use_attention use_autoregression num_lstm_units weight_decay '
    'lstm_state_clip_value')

# Hyper parameters for Model.sequence_loss_fn.
SequenceLossParams = collections.namedtuple(
    'SequenceLossParams',
    'label_smoothing ignore_nulls average_across_timesteps')

# Hyper parameters for Model.encode_coordinates_fn.
EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams',
                                                 'enabled')

def _dict_to_array(id_to_char, default_character):
  num_char_classes = max(id_to_char.keys()) + 1
  array = [default_character] * num_char_classes
65
  for k, v in id_to_char.items():
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
    array[k] = v
  return array


class CharsetMapper(object):
  """A simple class to map tensor ids into strings.

  It works only when the character set is 1:1 mapping between individual
  characters and individual ids.

  Make sure you call tf.tables_initializer().run() as part of the init op.
  """

  def __init__(self, charset, default_character='?'):
    """Creates a lookup table.

    Args:
      charset: a dictionary with id-to-character mapping.
      default_character: string used for ids that have no entry in `charset`
        (gaps below the largest id) and, as the table's default_value, for
        out-of-vocabulary ids.
    """
    mapping_strings = tf.constant(_dict_to_array(charset, default_character))
    self.table = tf.contrib.lookup.index_to_string_table_from_tensor(
        mapping=mapping_strings, default_value=default_character)

  def get_text(self, ids):
    """Returns the strings corresponding to sequences of character ids.

    Args:
      ids: a tensor with shape [batch_size, max_sequence_length]

    Returns:
      A string tensor with shape [batch_size]; element i is the concatenation
      of the characters looked up for row i of `ids`.
    """
    # The lookup table is keyed by int64 indices.
    return tf.strings.reduce_join(
        inputs=self.table.lookup(tf.cast(ids, dtype=tf.int64)), axis=1)


def get_softmax_loss_fn(label_smoothing):
  """Picks the softmax cross-entropy variant matching the label format.

  Args:
    label_smoothing: weight for label smoothing. A positive value means the
      labels are dense (smoothed) distributions; otherwise they are sparse
      class ids.

  Returns:
    A callable `(labels, logits) -> loss` computing the dense or sparse
    softmax cross-entropy respectively.
  """

  def _dense_loss(labels, logits):
    # Smoothed labels are derived from the ground truth; block gradients
    # from flowing into them.
    return (tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=tf.stop_gradient(labels)))

  def _sparse_loss(labels, logits):
    return tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)

  return _dense_loss if label_smoothing > 0 else _sparse_loss
def get_tensor_dimensions(tensor):
  """Splits the shape of a 4D tensor into a dynamic batch and static rest.

  Args:
    tensor : A 4D tensor, whose last 3 dimensions are known at graph
      construction time.

  Returns:
    batch_size : The first dimension, as a (dynamic) tensor object.
    height : The second dimension, as a Python scalar.
    width : The third dimension, as a Python scalar.
    num_features : The fourth dimension, as a Python scalar.

  Raises:
    ValueError: if input tensor does not have 4 dimensions.
  """
  dims = tensor.get_shape().dims
  if len(dims) != 4:
    raise ValueError(
        'Incompatible shape: len(tensor.get_shape().dims) != 4 (%d != 4)' %
        len(dims))
  batch_size = tf.shape(input=tensor)[0]
  height, width, num_features = (dim.value for dim in dims[1:])
  return batch_size, height, width, num_features


def lookup_indexed_value(indices, row_vecs):
  """Picks one element per row: row_vecs[i, indices[i]] for every i.

  Args:
    indices : A tensor of shape (batch, ) with a column index for each row.
    row_vecs : A tensor of shape [batch, depth]

  Returns:
    A tensor of shape (batch, ) formed by row_vecs[i, indices[i]].
  """
  row_ids = tf.range(tf.shape(input=row_vecs)[0], dtype=tf.int32)
  col_ids = tf.cast(indices, tf.int32)
  # Each row of `coords` is a (row, column) pair for tf.gather_nd.
  coords = tf.stack((row_ids, col_ids), axis=1)
  return tf.gather_nd(row_vecs, coords)


@utils.ConvertAllInputsToTensors
def max_char_logprob_cumsum(char_log_prob):
  """Computes the cumulative sum of the max character logprob per position.

  Args:
    char_log_prob: A tensor of shape [batch x seq_length x num_char_classes]
      with log probabilities of a character.

  Returns:
    A tensor of shape [batch x seq_length] (seq_length of the input) where
    each element x[_, j] is the sum of the max char logprob over positions
    0..j inclusive.
    NOTE: this function does not pad its input. null_based_length_prediction
    right-pads char_log_prob with one all-zero position before calling it, so
    that the result has (seq_length+1) columns with the final column
    duplicating the previous one; this lets the same result be indexed by any
    predicted length in [0, seq_length].
  """
  # Best log probability at each position, shape [batch x seq_length].
  max_char_log_prob = tf.reduce_max(input_tensor=char_log_prob, axis=2)
  # For an input array [a, b, c] tf.cumsum returns [a, a + b, a + b + c] if
  # exclusive set to False (default).
  return tf.cumsum(max_char_log_prob, axis=1, exclusive=False)

def find_length_by_null(predicted_chars, null_code):
  """Counts the non-NULL predicted characters of every sample.

  The length is the number of positions whose id differs from null_code, not
  the position of the first null_code.

  Args:
    predicted_chars: A tensor of [batch x seq_length] where each element stores
      the char class ID with max probability;
    null_code: an int32, character id for the NULL.

  Returns:
    A [batch, ] int32 tensor with the sequence length of each sample.
  """
  is_real_char = tf.not_equal(null_code, predicted_chars)
  return tf.reduce_sum(input_tensor=tf.cast(is_real_char, tf.int32), axis=1)


def axis_pad(tensor, axis, before=0, after=0, constant_values=0.0):
  """Pads `tensor` with a constant along exactly one axis.

  Args:
    tensor: a Tensor;
    axis: index of the dimension to pad;
    before: how many values to insert ahead of the existing contents along
      `axis`;
    after: how many values to append behind the existing contents along
      `axis`;
    constant_values: the scalar pad value to use. Must be same type as tensor.

  Returns:
    The input itself when no padding is requested; otherwise a Tensor of the
    same type whose size along `axis` grew by before + after.
  """
  if (before, after) == (0, 0):
    return tensor
  paddings = np.zeros((tensor.shape.ndims, 2), dtype='int32')
  paddings[axis, :] = (before, after)
  return tf.pad(tensor=tensor,
                paddings=tf.constant(paddings),
                constant_values=constant_values)

def null_based_length_prediction(chars_log_prob, null_code):
  """Computes length and confidence of prediction based on positions of NULLs.

  Args:
    chars_log_prob: A tensor of shape [batch x seq_length x num_char_classes]
      with log probabilities of a character;
    null_code: an int32, character id for the NULL.

  Returns:
    A tuple (text_log_prob, predicted_length), where
    text_log_prob - is a tensor of shape [batch x (seq_length+1)]; element #j
      is the cumulative sum of the max char log probability over positions
      0..j of chars_log_prob right-padded with one all-zero position (so the
      last column repeats the previous one). It is meant to be indexed by a
      predicted length: element #0 corresponds to the empty string, element
      #seq_length - to length=seq_length.
    predicted_length - an int32 tensor with shape [batch] holding the number
      of positions whose argmax character is not null_code.
  """
  predicted_chars = tf.cast(
      tf.argmax(input=chars_log_prob, axis=2), dtype=tf.int32)
  # We do right pad to support sequences with seq_length elements.
  text_log_prob = max_char_logprob_cumsum(
      axis_pad(chars_log_prob, axis=1, after=1))
  predicted_length = find_length_by_null(predicted_chars, null_code)
  return text_log_prob, predicted_length
class Model(object):
  """Class to create the Attention OCR Model."""

  def __init__(self,
               num_char_classes,
               seq_length,
               num_views,
               null_code,
               mparams=None,
               charset=None):
    """Initialized model parameters.

    Args:
      num_char_classes: size of character set.
      seq_length: number of characters in a sequence.
      num_views: Number of views (conv towers) to use.
      null_code: A character code corresponding to a character which indicates
        end of a sequence.
      mparams: a dictionary with hyper parameters for methods,  keys - function
        names, values - corresponding namedtuples. Entries override the
        corresponding entries of default_mparams().
      charset: an optional dictionary with a mapping between character ids and
        utf8 strings. If specified the OutputEndpoints.predicted_text will
        contain utf8 encoded strings corresponding to the character ids
        returned by OutputEndpoints.predicted_chars (by default the
        predicted_text contains an empty vector).
        NOTE: Make sure you call tf.tables_initializer().run() if the charset
          specified.
    """
    super(Model, self).__init__()
    self._params = ModelParams(
        num_char_classes=num_char_classes,
        seq_length=seq_length,
        num_views=num_views,
        null_code=null_code)
    self._mparams = self.default_mparams()
    if mparams:
      self._mparams.update(mparams)
    self._charset = charset

  def default_mparams(self):
    """Returns the default per-method hyper parameters.

    Returns:
      A dictionary whose keys are names of Model methods and whose values are
      the namedtuples with the hyper parameters read by those methods.
    """
    return {
        'conv_tower_fn':
            ConvTowerParams(final_endpoint='Mixed_5d'),
        'sequence_logit_fn':
            SequenceLogitsParams(
                use_attention=True,
                use_autoregression=True,
                num_lstm_units=256,
                weight_decay=0.00004,
                lstm_state_clip_value=10.0),
        'sequence_loss_fn':
            SequenceLossParams(
                label_smoothing=0.1,
                ignore_nulls=True,
                average_across_timesteps=False),
        'encode_coordinates_fn':
            EncodeCoordinatesParams(enabled=False)
    }

  def set_mparam(self, function, **kwargs):
    """Overrides selected hyper parameters of a single method.

    Args:
      function: method name, i.e. a key of the hyper parameter dictionary.
      **kwargs: field names and new values; applied with namedtuple._replace,
        so unknown field names raise ValueError.
    """
    self._mparams[function] = self._mparams[function]._replace(**kwargs)

  def conv_tower_fn(self, images, is_training=True, reuse=None):
    """Computes convolutional features using the InceptionV3 model.

    Args:
      images: A tensor of shape [batch_size, height, width, channels].
      is_training: whether is training or not.
      reuse: whether or not the network and its variables should be reused. To
        be able to reuse 'scope' must be given.

    Returns:
      A tensor of shape [batch_size, OH, OW, N], where OWxOH is resolution of
      output feature map and N is number of output features (depends on the
      network architecture).
    """
    mparams = self._mparams['conv_tower_fn']
    logging.debug('Using final_endpoint=%s', mparams.final_endpoint)
    with tf.compat.v1.variable_scope('conv_tower_fn/INCE'):
      if reuse:
        tf.compat.v1.get_variable_scope().reuse_variables()
      with slim.arg_scope(inception.inception_v3_arg_scope()):
        with slim.arg_scope([slim.batch_norm, slim.dropout],
                            is_training=is_training):
          net, _ = inception.inception_v3_base(
              images, final_endpoint=mparams.final_endpoint)
      return net

  def _create_lstm_inputs(self, net):
    """Splits an input tensor into a list of tensors (features).

    Args:
      net: A feature map of shape [batch_size, num_features, feature_size].

    Raises:
      AssertionError: if num_features is less than seq_length.

    Returns:
      A list with seq_length tensors of shape [batch_size, feature_size]
    """
    num_features = net.get_shape().dims[1].value
    if num_features < self._params.seq_length:
      raise AssertionError(
          'Incorrect dimension #1 of input tensor'
          ' %d should be bigger than %d (shape=%s)' %
          (num_features, self._params.seq_length, net.get_shape()))
    elif num_features > self._params.seq_length:
      logging.warning('Ignoring some features: use %d of %d (shape=%s)',
                      self._params.seq_length, num_features, net.get_shape())
      net = tf.slice(net, [0, 0, 0], [-1, self._params.seq_length, -1])

    return tf.unstack(net, axis=1)

  def sequence_logit_fn(self, net, labels_one_hot):
    """Computes character logits with the configured sequence layer.

    Args:
      net: a tensor of shape [batch_size, num_features, feature_size], as
        produced by pool_views_fn.
      labels_one_hot: optional (can be None) one-hot ground truth labels,
        forwarded to the sequence layer.

    Returns:
      The result of create_logits() of the layer class selected by
      sequence_layers.get_layer_class; create_base uses it as chars_logit.
    """
    mparams = self._mparams['sequence_logit_fn']
    # TODO(gorban): remove /alias suffixes from the scopes.
    with tf.compat.v1.variable_scope('sequence_logit_fn/SQLR'):
      layer_class = sequence_layers.get_layer_class(mparams.use_attention,
                                                    mparams.use_autoregression)
      layer = layer_class(net, labels_one_hot, self._params, mparams)
      return layer.create_logits()

  def max_pool_views(self, nets_list):
    """Element-wise max across views.

    Each view is flattened to [batch_size, 1, height * width, num_features],
    the views are concatenated along axis 1 and a [len(nets_list), 1] max pool
    with stride 1 reduces them to a single view.

    NOTE: every dimension (including batch_size) must be known at graph
    construction time, because the reshapes use the static shape of
    nets_list[0].

    Args:
      nets_list: A list of 4D tensors with identical size.

    Returns:
      A tensor with the same size as any input tensors.
    """
    batch_size, height, width, num_features = [
        d.value for d in nets_list[0].get_shape().dims
    ]
    xy_flat_shape = (batch_size, 1, height * width, num_features)
    nets_for_merge = []
    with tf.compat.v1.variable_scope('max_pool_views', values=nets_list):
      for net in nets_list:
        nets_for_merge.append(tf.reshape(net, xy_flat_shape))
      merged_net = tf.concat(nets_for_merge, 1)
      net = slim.max_pool2d(
          merged_net, kernel_size=[len(nets_list), 1], stride=1)
      net = tf.reshape(net, (batch_size, height, width, num_features))
    return net

  def pool_views_fn(self, nets):
    """Combines output of multiple convolutional towers into a single tensor.

    It stacks towers one on top another (in height dim), e.g. in a 4x1 grid
    for four views. The order is arbitrary design choice and shouldn't matter
    much. The stacked map is then flattened spatially.

    Args:
      nets: list of tensors of shape=[batch_size, height, width, num_features].

    Returns:
      A tensor of shape [batch_size, H * width, features_size], where H is the
      total height of the stacked views.
    """
    with tf.compat.v1.variable_scope('pool_views_fn/STCK'):
      net = tf.concat(nets, 1)
      batch_size = tf.shape(input=net)[0]
      image_size = net.get_shape().dims[1].value * \
          net.get_shape().dims[2].value
      feature_size = net.get_shape().dims[3].value
      return tf.reshape(net, tf.stack([batch_size, image_size, feature_size]))

  def char_predictions(self, chars_logit):
    """Returns confidence scores (softmax values) for predicted characters.

    Args:
      chars_logit: chars logits, a tensor with shape [batch_size x seq_length x
        num_char_classes]

    Returns:
      A tuple (ids, log_prob, scores), where:
        ids - predicted characters, a int32 tensor with shape
          [batch_size x seq_length];
        log_prob - a log probability of all characters, a float tensor with
          shape [batch_size, seq_length, num_char_classes];
        scores - corresponding confidence scores for characters, a float
          tensor with shape [batch_size x seq_length].
    """
    log_prob = utils.logits_to_log_prob(chars_logit)
    ids = tf.cast(tf.argmax(input=log_prob, axis=2),
                  name='predicted_chars', dtype=tf.int32)
    # Select the softmax score of the argmax class at every position.
    mask = tf.cast(
        slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
    all_scores = tf.nn.softmax(chars_logit)
    selected_scores = tf.boolean_mask(
        tensor=all_scores, mask=mask, name='char_scores')
    scores = tf.reshape(
        selected_scores,
        shape=(-1, self._params.seq_length),
        name='predicted_scores')
    return ids, log_prob, scores

  def encode_coordinates_fn(self, net):
    """Adds one-hot encoding of coordinates to different views in the networks.

    For each "pixel" of a feature map it adds a onehot encoded x and y
    coordinates.

    Args:
      net: a tensor of shape=[batch_size, height, width, num_features]

    Returns:
      a tensor with the same height and width, but altered feature_size
      (num_features + height + width when enabled, unchanged otherwise).
    """
    mparams = self._mparams['encode_coordinates_fn']
    if mparams.enabled:
      batch_size, h, w, _ = get_tensor_dimensions(net)
      x, y = tf.meshgrid(tf.range(w), tf.range(h))
      w_loc = slim.one_hot_encoding(x, num_classes=w)
      h_loc = slim.one_hot_encoding(y, num_classes=h)
      loc = tf.concat([h_loc, w_loc], 2)
      loc = tf.tile(tf.expand_dims(loc, 0), tf.stack([batch_size, 1, 1, 1]))
      return tf.concat([net, loc], 3)
    else:
      return net

  def create_base(self,
                  images,
                  labels_one_hot,
                  scope='AttentionOcr_v1',
                  reuse=None):
    """Creates a base part of the Model (no gradients, losses or summaries).

    The images are split into num_views equal parts along the width axis; each
    part goes through its own application of the (shared) conv tower.

    Args:
      images: A tensor of shape [batch_size, height, width, channels] with pixel
        values in the range [0.0, 1.0].
      labels_one_hot: Optional (can be None) one-hot encoding for ground truth
        labels. If provided the function will create a model for training.
      scope: Optional variable_scope.
      reuse: whether or not the network and its variables should be reused. To
        be able to reuse 'scope' must be given.

    Returns:
      A named tuple OutputEndpoints.
    """
    logging.debug('images: %s', images)
    is_training = labels_one_hot is not None

    # Normalize image pixel values to have a symmetrical range around zero.
    images = tf.subtract(images, 0.5)
    images = tf.multiply(images, 2.5)

    with tf.compat.v1.variable_scope(scope, reuse=reuse):
      views = tf.split(
          value=images, num_or_size_splits=self._params.num_views, axis=2)
      logging.debug('Views=%d single view: %s', len(views), views[0])

      # All views share the conv tower weights: only the first creates them.
      nets = [
          self.conv_tower_fn(v, is_training, reuse=(i != 0))
          for i, v in enumerate(views)
      ]
      logging.debug('Conv tower: %s', nets[0])

      nets = [self.encode_coordinates_fn(net) for net in nets]
      logging.debug('Conv tower w/ encoded coordinates: %s', nets[0])

      net = self.pool_views_fn(nets)
      logging.debug('Pooled views: %s', net)

      chars_logit = self.sequence_logit_fn(net, labels_one_hot)
      logging.debug('chars_logit: %s', chars_logit)

      predicted_chars, chars_log_prob, predicted_scores = (
          self.char_predictions(chars_logit))
      if self._charset:
        character_mapper = CharsetMapper(self._charset)
        predicted_text = character_mapper.get_text(predicted_chars)
      else:
        predicted_text = tf.constant([])

      text_log_prob, predicted_length = null_based_length_prediction(
          chars_log_prob, self._params.null_code)
      predicted_conf = lookup_indexed_value(predicted_length, text_log_prob)
      # Convert predicted confidence from sum of logs to geometric mean
      normalized_seq_conf = tf.exp(
          tf.divide(predicted_conf,
                    tf.cast(predicted_length + 1, predicted_conf.dtype)),
          name='normalized_seq_conf')
      predicted_conf = tf.identity(predicted_conf, name='predicted_conf')
      predicted_text = tf.identity(predicted_text, name='predicted_text')
      predicted_length = tf.identity(predicted_length, name='predicted_length')

    return OutputEndpoints(
        chars_logit=chars_logit,
        chars_log_prob=chars_log_prob,
        predicted_chars=predicted_chars,
        predicted_scores=predicted_scores,
        predicted_length=predicted_length,
        predicted_text=predicted_text,
        predicted_conf=predicted_conf,
        normalized_seq_conf=normalized_seq_conf)

  def create_loss(self, data, endpoints):
    """Creates all losses required to train the model.

    Args:
      data: InputEndpoints namedtuple.
      endpoints: OutputEndpoints namedtuple.

    Returns:
      Total loss.
    """
    # NOTE: the return value of sequence_loss_fn is not used directly for the
    # gradient computation: sequence_loss_fn registers its loss via
    # tf.compat.v1.losses.add_loss, and slim.losses.get_total_loss() is used
    # to collect it together with any other registered losses (including
    # regularization losses), since the model may have several of them.
    self.sequence_loss_fn(endpoints.chars_logit, data.labels)
    total_loss = slim.losses.get_total_loss()
    tf.compat.v1.summary.scalar('TotalLoss', total_loss)
    return total_loss

  def label_smoothing_regularization(self, chars_labels, weight=0.1):
    """Applies a label smoothing regularization.

    Uses the same method as in https://arxiv.org/abs/1512.00567.

    Args:
      chars_labels: ground truth ids of characters, shape=[batch_size,
        seq_length];
      weight: label-smoothing regularization weight.

    Returns:
      A tensor of shape [batch_size, seq_length, num_char_classes]: the one-hot
      labels scaled by (1 - weight) plus weight / num_char_classes.
    """
    one_hot_labels = tf.one_hot(
        chars_labels, depth=self._params.num_char_classes, axis=-1)
    pos_weight = 1.0 - weight
    neg_weight = weight / self._params.num_char_classes
    return one_hot_labels * pos_weight + neg_weight

  def sequence_loss_fn(self, chars_logits, chars_labels):
    """Loss function for char sequence.

    Depending on values of hyper parameters it applies label smoothing and can
    give zero weight to positions whose label is the reject character.

    Args:
      chars_logits: logits for predicted characters, shape=[batch_size,
        seq_length, num_char_classes];
      chars_labels: ground truth ids of characters, shape=[batch_size,
        seq_length];

    Returns:
      The scalar Tensor returned by tf.contrib.legacy_seq2seq.sequence_loss
      (weighted log-perplexity, averaged across the batch). The same tensor is
      also registered with tf.compat.v1.losses.add_loss.
    """
    mparams = self._mparams['sequence_loss_fn']
    with tf.compat.v1.variable_scope('sequence_loss_fn/SLF'):
      if mparams.label_smoothing > 0:
        smoothed_one_hot_labels = self.label_smoothing_regularization(
            chars_labels, mparams.label_smoothing)
        labels_list = tf.unstack(smoothed_one_hot_labels, axis=1)
      else:
        # NOTE: in case of sparse softmax we are not using one-hot
        # encoding.
        labels_list = tf.unstack(chars_labels, axis=1)

      # Weights are shaped after chars_labels (not the static shape of
      # chars_logits) so that an unknown (None) batch size is supported.
      if mparams.ignore_nulls:
        weights = tf.ones_like(chars_labels, dtype=tf.float32)
      else:
        # NOTE(review): ignore_nulls=True weights every position equally while
        # ignore_nulls=False zeroes the weight of positions labelled with the
        # reject character; confirm this matches the flag's intended meaning.
        # Suppose that reject character is the last in the charset.
        reject_char = tf.constant(
            self._params.num_char_classes - 1, dtype=tf.int64)
        known_char = tf.not_equal(chars_labels, reject_char)
        weights = tf.cast(known_char, dtype=tf.float32)

      logits_list = tf.unstack(chars_logits, axis=1)
      weights_list = tf.unstack(weights, axis=1)
      loss = tf.contrib.legacy_seq2seq.sequence_loss(
          logits_list,
          labels_list,
          weights_list,
          softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing),
          average_across_timesteps=mparams.average_across_timesteps)
      tf.compat.v1.losses.add_loss(loss)
      return loss

  def create_summaries(self, data, endpoints, charset, is_training):
    """Creates all summaries for the model.

    Args:
      data: InputEndpoints namedtuple.
      endpoints: OutputEndpoints namedtuple.
      charset: A dictionary with mapping between character codes and unicode
        characters. Use the one provided by a dataset.charset. Currently
        unused: the text summaries that would need it are disabled (see the
        TODO below).
      is_training: If True will create summary prefixes for training job,
        otherwise - for evaluation.

    Returns:
      A list of evaluation ops (the metric update ops) for evaluation, None
      for training.
    """

    def sname(label):
      prefix = 'train' if is_training else 'eval'
      return '%s/%s' % (prefix, label)

    max_outputs = 4
    # TODO(gorban): uncomment, when tf.summary.text released.
    # charset_mapper = CharsetMapper(charset)
    # pr_text = charset_mapper.get_text(
    #     endpoints.predicted_chars[:max_outputs,:])
    # tf.summary.text(sname('text/pr'), pr_text)
    # gt_text = charset_mapper.get_text(data.labels[:max_outputs,:])
    # tf.summary.text(sname('text/gt'), gt_text)
    tf.compat.v1.summary.image(
        sname('image'), data.images, max_outputs=max_outputs)

    if is_training:
      tf.compat.v1.summary.image(
          sname('image/orig'), data.images_orig, max_outputs=max_outputs)
      for var in tf.compat.v1.trainable_variables():
        tf.compat.v1.summary.histogram(var.op.name, var)
      return None

    else:
      names_to_values = {}
      names_to_updates = {}

      def use_metric(name, value_update_tuple):
        names_to_values[name] = value_update_tuple[0]
        names_to_updates[name] = value_update_tuple[1]

      use_metric(
          'CharacterAccuracy',
          metrics.char_accuracy(
              endpoints.predicted_chars,
              data.labels,
              streaming=True,
              rej_char=self._params.null_code))
      # Sequence accuracy computed by cutting sequence at the first null char
      use_metric(
          'SequenceAccuracy',
          metrics.sequence_accuracy(
              endpoints.predicted_chars,
              data.labels,
              streaming=True,
              rej_char=self._params.null_code))

      for name, value in names_to_values.items():
        summary_name = 'eval/' + name
        tf.compat.v1.summary.scalar(
            summary_name, tf.compat.v1.Print(value, [value], summary_name))
      return list(names_to_updates.values())

  def create_init_fn_to_restore(self,
                                master_checkpoint,
                                inception_checkpoint=None):
    """Creates an init operations to restore weights from various checkpoints.

    Args:
      master_checkpoint: path to a checkpoint which contains all weights for the
        whole model.
      inception_checkpoint: path to a checkpoint which contains weights for the
        inception part only.

    Returns:
      a function to run initialization ops; it takes a session and runs all
      collected assign ops with their feed dict.
    """
    all_assign_ops = []
    all_feed_dict = {}

    def assign_from_checkpoint(variables, checkpoint):
      logging.info('Request to re-store %d weights from %s', len(variables),
                   checkpoint)
      if not variables:
        # NOTE: terminates the whole process rather than raising.
        logging.error('Can\'t find any variables to restore.')
        sys.exit(1)
      assign_op, feed_dict = slim.assign_from_checkpoint(checkpoint, variables)
      all_assign_ops.append(assign_op)
      all_feed_dict.update(feed_dict)

    logging.info('variables_to_restore:\n%s',
                 utils.variables_to_restore().keys())
    logging.info('moving_average_variables:\n%s',
                 [v.op.name for v in tf.compat.v1.moving_average_variables()])
    logging.info('trainable_variables:\n%s',
                 [v.op.name for v in tf.compat.v1.trainable_variables()])
    if master_checkpoint:
      assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint)

    if inception_checkpoint:
      variables = utils.variables_to_restore(
          'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True)
      assign_from_checkpoint(variables, inception_checkpoint)

    def init_assign_fn(sess):
      logging.info('Restoring checkpoint(s)')
      sess.run(all_assign_ops, all_feed_dict)

    return init_assign_fn