# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Some simple tools for error counting.

"""
import collections

# The ErrorCounts named tuple describes the counts needed to accumulate errors
# over multiple trials:
#   fn:          false negatives (aka drops or deletions),
#   fp:          false positives (aka adds or insertions),
#   truth_count: number of elements in the ground truth = denominator for fn,
#   test_count:  number of elements in the test string = denominator for fp.
# Note that recall = 1 - fn/truth_count, precision = 1 - fp/test_count,
# accuracy = 1 - (fn + fp) / (truth_count + test_count).
ErrorCounts = collections.namedtuple('ErrorCounts', ['fn', 'fp', 'truth_count',
                                                     'test_count'])
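
# A worked example (illustrative, not part of the original module): comparing
# test string 'abce' against truth 'abcd' gives fn=1 ('d' dropped), fp=1
# ('e' added), truth_count=4, test_count=4, so recall = precision = 0.75 and
# accuracy = 1 - 2/8 = 0.75.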

# Named tuple for error rates, as a percentage. Accuracies are just 100-error.
ErrorRates = collections.namedtuple('ErrorRates',
                                    ['label_error', 'word_recall_error',
                                     'word_precision_error', 'sequence_error'])


def CountWordErrors(ocr_text, truth_text):
  """Counts the word drop and add errors as a bag of words.

  Args:
    ocr_text:    OCR text string.
    truth_text:  Truth text string.

  Returns:
    ErrorCounts named tuple.
  """
  # Convert to lists of words.
  return CountErrors(ocr_text.split(), truth_text.split())


def CountErrors(ocr_text, truth_text):
  """Counts the drops and adds between 2 bags of iterables.

  A simple bag-of-objects count returns the number of dropped and added
  elements, regardless of order, from anything that is iterable, e.g.
  a pair of strings gives character errors, and a pair of word lists
  gives word errors.

  Args:
    ocr_text:    OCR text iterable (e.g. string for chars, word list for words).
    truth_text:  Truth text iterable.

  Returns:
    ErrorCounts named tuple.
  """
  counts = collections.Counter(truth_text)
  counts.subtract(ocr_text)
  # Positive residuals are truth elements missing from the OCR output (drops);
  # negative residuals are OCR elements absent from the truth (adds).
  drops = sum(c for c in counts.values() if c > 0)
  adds = sum(-c for c in counts.values() if c < 0)
  return ErrorCounts(drops, adds, len(truth_text), len(ocr_text))
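
# For illustration (a sketch, not part of the original module), the same
# function counts characters or words depending on what is iterated:
#   CountErrors('help', 'hello')
#     => ErrorCounts(fn=2, fp=1, truth_count=5, test_count=4)
#   CountErrors('the cat sat'.split(), 'the cat sat down'.split())
#     => ErrorCounts(fn=1, fp=0, truth_count=4, test_count=3)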


def AddErrors(counts1, counts2):
  """Adds the counts and returns a new sum tuple.

  Args:
    counts1: First ErrorCounts named tuple to sum.
    counts2: Second ErrorCounts named tuple to sum.

  Returns:
    Sum of counts1, counts2.
  """
  return ErrorCounts(counts1.fn + counts2.fn, counts1.fp + counts2.fp,
                     counts1.truth_count + counts2.truth_count,
                     counts1.test_count + counts2.test_count)


def ComputeErrorRates(label_counts, word_counts, seq_errors, num_seqs):
  """Returns an ErrorRates corresponding to the given counts.

  Args:
    label_counts: ErrorCounts for the character labels.
    word_counts:  ErrorCounts for the words.
    seq_errors:   Number of sequence errors.
    num_seqs:     Total number of sequences.

  Returns:
    ErrorRates corresponding to the given counts.
  """
  label_errors = label_counts.fn + label_counts.fp
  num_labels = label_counts.truth_count + label_counts.test_count
  return ErrorRates(
      ComputeErrorRate(label_errors, num_labels),
      ComputeErrorRate(word_counts.fn, word_counts.truth_count),
      ComputeErrorRate(word_counts.fp, word_counts.test_count),
      ComputeErrorRate(seq_errors, num_seqs))
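
# For example (illustrative, not part of the original module): with
# label_counts = ErrorCounts(2, 1, 5, 4), word_counts = ErrorCounts(1, 0, 4, 3)
# and 1 sequence error out of 2 sequences, label_error = 100*3/9 = 33.3,
# word_recall_error = 25.0, word_precision_error = 0.0, sequence_error = 50.0.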


def ComputeErrorRate(error_count, truth_count):
  """Returns a sanitized percent error rate from the raw counts.

  Prevents division by zero and clips the result to at most 100%.

  Args:
    error_count: Number of errors.
    truth_count: Number to divide by.

  Returns:
    100.0 * error_count / truth_count clipped to 100.
  """
  if truth_count == 0:
    # With no truth to compare against, report a full 100% error rate.
    truth_count = 1
    error_count = 1
  elif error_count > truth_count:
    # Clip so the returned percentage never exceeds 100%.
    error_count = truth_count
  return error_count * 100.0 / truth_count
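

# A minimal usage sketch (illustrative, not part of the original module):
# accumulate character- and word-level counts over two OCR "trials" and
# report the resulting error rates. The example strings are assumptions.
if __name__ == '__main__':
  truths = ['the quick brown fox', 'jumps over the lazy dog']
  ocr_outputs = ['the quick brwn fox', 'jumps over the lazy d0g']
  label_totals = ErrorCounts(0, 0, 0, 0)
  word_totals = ErrorCounts(0, 0, 0, 0)
  seq_errors = 0
  for ocr, truth in zip(ocr_outputs, truths):
    label_totals = AddErrors(label_totals, CountErrors(ocr, truth))
    word_totals = AddErrors(word_totals, CountWordErrors(ocr, truth))
    if ocr != truth:
      seq_errors += 1
  print(ComputeErrorRates(label_totals, word_totals, seq_errors, len(truths)))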