hist.py 7.71 KB
Newer Older
mibaumgartner's avatar
mibaumgartner committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import time
import numpy as np

from pathlib import Path
from loguru import logger
from typing import Sequence, List, Dict, Any, Tuple


import matplotlib.pyplot as plt

from nndet.evaluator import DetectionMetric


class PredictionHistogram(DetectionMetric):
    def __init__(self,
                 classes: Sequence[str], save_dir: Path,
                 iou_thresholds: Sequence[float] = (0.1, 0.5),
                 bins: int = 50):
        """
        Class to compute prediction histograms. (Note: this class does not
        provide any scalar metrics)

        Args:
            classes: name of each class (index needs to correspond to predicted class indices!)
            save_dir: directory where histograms are saved to; if None, the
                figures are created and closed without being written to disk
            iou_thresholds: IoU thresholds for which a histogram is plotted
            bins: number of bins of histogram
        """
        self.classes = classes
        self.save_dir = save_dir

        self.iou_thresholds = iou_thresholds
        self.bins = bins

    def get_iou_thresholds(self) -> Sequence[float]:
        """
        Return IoU thresholds needed for this metric

        Returns:
            Sequence[float]: IoU thresholds [M], M is the number of thresholds
                (the sequence which was passed to the constructor)
        """
        return self.iou_thresholds

    def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[
            Dict[str, float], Dict[str, Dict[str, Any]]]:
        """
        Plot class independent and per class histograms. For more info see
        :meth:`plot_hist`

        Args:
            results_list: list with results per image (in list) per category
                (dict, keyed by class index)

        Returns:
            Dict: empty (this metric does not provide scalar values)
            Dict: empty (this metric does not provide curves)
        """
        # class independent histogram over all results
        self.plot_hist(results_list=results_list)
        for cls_idx, cls_str in enumerate(self.classes):
            # filter current class from list of results and put them into a dict with a single entry
            results_by_cls = [{0: r[cls_idx]} for r in results_list if cls_idx in r]
            self.plot_hist(results_by_cls, title_prefix=f"cl_{cls_str}_")
        return {}, {}

    def plot_hist(self, results_list: List[Dict[int, Dict[str, np.ndarray]]],
                  title_prefix: str = "") -> Tuple[
                    Dict[str, float], Dict[str, Dict[str, Any]]]:
        """
        Plot prediction histograms for multiple IoU values

        Args:
            results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
                per category (dict). The following entries of the inner dict are read:
                `dtMatches`: matched detections [T, D], where T = number of thresholds,
                    D = number of detections
                `dtScores`: prediction scores [D] detection scores
                `gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth
                    should be ignored, G = number of ground truth
                `dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored
            title_prefix: prefix for title of histogram plot (also used as
                prefix of the file name)

        Returns:
            Dict: empty
            Dict: empty (the histograms are only plotted, see
                :meth:`compute_histogram_one_iou`)
        """
        num_images = len(results_list)
        results = [_r for r in results_list for _r in r.values()]

        if not results:
            logger.warning("WARNING, no results found for prediction histogram computation")
            return {}, {}

        # r['dtMatches'] [T, R], where R = sum(all detections)
        dt_matches = np.concatenate([r['dtMatches'] for r in results], axis=1)
        dt_ignores = np.concatenate([r['dtIgnore'] for r in results], axis=1)
        dt_scores = np.concatenate([r['dtScores'] for r in results])
        gt_ignore = np.concatenate([r['gtIgnore'] for r in results])
        # inherited helper; presumably validates the threshold axis - confirm in DetectionMetric
        self.check_number_of_iou(dt_matches, dt_ignores)

        num_gt = np.count_nonzero(gt_ignore == 0)  # number of ground truth boxes (non ignored)
        if num_gt == 0:
            logger.error("No ground truth found! Returning nothing.")
            return {}, {}

        for iou_idx, iou_val in enumerate(self.iou_thresholds):
            # Drop ignored detections from the matches AND the scores with the
            # same mask so both arrays stay aligned (filtering only the scores
            # breaks as soon as a single detection is ignored).
            keep = np.logical_not(dt_ignores[iou_idx])
            self.compute_histogram_one_iou(
                dt_matches[iou_idx][keep], dt_scores[keep],
                num_images, num_gt, iou_val, title_prefix)
        return {}, {}

    def compute_histogram_one_iou(self, dt_matches: np.ndarray, dt_scores: np.ndarray,
                                  num_images: int, num_gt: int, iou: float,
                                  title_prefix: str) -> None:
        """
        Plot prediction histogram

        Args:
            dt_matches (np.ndarray): binary array indicating which bounding
                boxes have a large enough overlap with gt;
                [R] where R is the number of predictions
            dt_scores (np.ndarray): prediction score for each bounding box;
                [R] where R is the number of predictions
            num_images (int): number of images (currently not used)
            num_gt (int): number of ground truth bounding boxes
            iou: IoU values which is currently evaluated
            title_prefix: prefix for title of histogram plot
        """
        true_positives = int(np.sum(dt_matches))
        false_positives = int(np.sum(dt_matches == 0))
        false_negatives = int(num_gt - true_positives)

        # false negatives are added as "matched" entries with score 0 so they
        # show up in the first bin of the true positive histogram
        _dt_matches = np.concatenate([dt_matches, [1] * false_negatives])
        _dt_scores = np.concatenate([dt_scores, [0] * false_negatives])

        # masks are derived from the padded array: false negatives need to be
        # plotted even if not a single detection was matched
        fp_mask = _dt_matches == 0
        tp_mask = _dt_matches == 1

        fig = plt.figure()
        try:
            plt.yscale('log')
            if fp_mask.any():
                plt.hist(_dt_scores[fp_mask], bins=self.bins, range=(0., 1.),
                         alpha=0.3, color='g', label='false pos.')
            if tp_mask.any():
                plt.hist(_dt_scores[tp_mask], bins=self.bins, range=(0., 1.),
                         alpha=0.3, color='b', label='true pos. (false neg. @ score=0)')
            plt.legend()
            title = title_prefix + (f"tp:{true_positives} fp:{false_positives} "
                                    f"fn:{false_negatives} pos:{true_positives+false_negatives}")
            plt.title(title)
            plt.xlabel('confidence score')
            plt.ylabel('log n')

            if self.save_dir is not None:
                # dots (e.g. from the IoU value) are replaced to keep a single file extension
                file_stem = f"{title_prefix}pred_hist_IoU@{iou}".replace(".", "_")
                save_path = Path(self.save_dir) / (file_stem + ".png")
                logger.info(f"Saving {save_path}")
                fig.savefig(save_path)
        finally:
            # always release the figure, even if plotting or saving failed
            plt.close(fig)
        return None