# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Object detection calibration metrics.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import metrics_impl


def _safe_div(numerator, denominator):
  """Divides element-wise, yielding 0 wherever `denominator` <= 0.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.

  Returns:
    A `Tensor` with the dtype produced by `tf.truediv(numerator, denominator)`,
    holding `numerator / denominator` where `denominator` > 0 and 0 elsewhere.
  """
  quotient = tf.truediv(numerator, denominator)
  # The comparison operand must share the denominator's dtype, while the
  # fallback must share the quotient's dtype (truediv may promote integers).
  denom_zeros = tf.zeros_like(quotient, dtype=denominator.dtype)
  is_positive = tf.greater(denominator, denom_zeros)
  return tf.where(is_positive, quotient, tf.cast(denom_zeros, quotient.dtype))


def _ece_from_bins(bin_counts, bin_true_sum, bin_preds_sum, name):
  """Computes Expected Calibration Error from per-bin accumulated statistics.

  Args:
    bin_counts: `Tensor` with the number of predictions that fell in each bin.
    bin_true_sum: `Tensor` with the sum of ground-truth labels in each bin.
    bin_preds_sum: `Tensor` with the sum of predicted confidences in each bin.
    name: Name given to the returned scalar op.

  Returns:
    Scalar `Tensor`: the sum over bins of |mean label - mean confidence|,
    each term weighted by the bin's share of all predictions. Empty bins
    (and an entirely empty histogram) contribute 0.
  """
  per_bin_gap = tf.abs(
      _safe_div(bin_true_sum, bin_counts) -
      _safe_div(bin_preds_sum, bin_counts))
  bin_share = _safe_div(bin_counts, tf.reduce_sum(bin_counts))
  return tf.reduce_sum(per_bin_gap * bin_share, name=name)


def expected_calibration_error(y_true, y_pred, nbins=20):
  """Calculates Expected Calibration Error (ECE).

  ECE is a scalar summary statistic of calibration error. It is the
  sample-weighted average of the difference between the predicted and true
  probabilities of a positive detection across uniformly-spaced model
  confidences [0, 1]. See referenced paper for a thorough explanation.

  Reference:
    Guo, et. al, "On Calibration of Modern Neural Networks"
    Page 2, Expected Calibration Error (ECE).
    https://arxiv.org/pdf/1706.04599.pdf

  This function creates three local variables, `bin_counts`, `bin_true_sum`,
  and `bin_preds_sum`, each a float32 vector of length `nbins`, that are used
  to compute ECE. For estimation of the metric over a stream of data, the
  function creates an `update_op` operation that updates these variables and
  returns the ECE.

  Args:
    y_true: 1-D tf.int64 Tensor of binarized ground truth, corresponding to each
      prediction in y_pred.
    y_pred: 1-D tf.float32 tensor of model confidence scores in range
      [0.0, 1.0].
    nbins: int specifying the number of uniformly-spaced bins into which y_pred
      will be bucketed.

  Returns:
    value_op: A value metric op that returns ece.
    update_op: An operation that increments the `bin_counts`, `bin_true_sum`,
      and `bin_preds_sum` variables appropriately and whose value matches `ece`.

  Raises:
    InvalidArgumentError: if y_pred is not in [0.0, 1.0].
  """
  bin_counts = metrics_impl.metric_variable(
      [nbins], tf.float32, name='bin_counts')
  bin_true_sum = metrics_impl.metric_variable(
      [nbins], tf.float32, name='true_sum')
  bin_preds_sum = metrics_impl.metric_variable(
      [nbins], tf.float32, name='preds_sum')

  # The range checks gate the bin assignment; every update below consumes
  # bin_ids, so an out-of-range y_pred fails before any variable is modified.
  with tf.control_dependencies([
      tf.assert_greater_equal(y_pred, 0.0),
      tf.assert_less_equal(y_pred, 1.0),
  ]):
    bin_ids = tf.histogram_fixed_width_bins(y_pred, [0.0, 1.0], nbins=nbins)

  # minlength=nbins keeps each increment vector at length nbins even when the
  # upper bins receive no predictions in this batch, matching the variables.
  with tf.control_dependencies([bin_ids]):
    update_bin_counts_op = tf.assign_add(
        bin_counts, tf.cast(tf.bincount(bin_ids, minlength=nbins),
                            dtype=tf.float32))
    update_bin_true_sum_op = tf.assign_add(
        bin_true_sum,
        tf.cast(tf.bincount(bin_ids, weights=y_true, minlength=nbins),
                dtype=tf.float32))
    update_bin_preds_sum_op = tf.assign_add(
        bin_preds_sum,
        tf.cast(tf.bincount(bin_ids, weights=y_pred, minlength=nbins),
                dtype=tf.float32))

  # The update op is built from the assign_add results, so it reports ECE over
  # everything accumulated so far, including the current batch.
  ece_update_op = _ece_from_bins(
      update_bin_counts_op,
      update_bin_true_sum_op,
      update_bin_preds_sum_op,
      name='update_op')
  ece = _ece_from_bins(bin_counts, bin_true_sum, bin_preds_sum, name='value')
  return ece, ece_update_op