abstract.py 3.86 KB
Newer Older
mibaumgartner's avatar
mibaumgartner committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Sequence, Tuple

import numpy as np


# Public API of this module.
__all__ = [
    "AbstractEvaluator",
    "AbstractMetric",
    "DetectionMetric",
]


class AbstractEvaluator(ABC):
    """
    Interface for evaluators which collect values batch by batch and
    turn them into metrics after all batches have been processed.
    """

    @abstractmethod
    def run_online_evaluation(self, *args, **kwargs):
        """
        Process a single batch and store the intermediate values which
        are required for the final evaluation
        """
        raise NotImplementedError()

    @abstractmethod
    def finish_online_evaluation(self, *args, **kwargs):
        """
        Combine the values stored over all batches and compute the metrics
        """
        raise NotImplementedError()

    @abstractmethod
    def reset(self):
        """
        Discard the internal state of the evaluator
        """
        raise NotImplementedError()


class AbstractMetric(ABC):
    """
    Base class for metrics computed from per image matching results.

    Subclasses implement :meth:`compute`; calling a metric instance forwards
    all arguments to :meth:`compute`.
    """

    def __call__(self, *args, **kwargs) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute metric. See :func:`compute` for more information.

        Args:
            *args: positional arguments passed to :func:`compute`
            **kwargs: keyword arguments passed to :func:`compute`

        Returns:
            Dict[str, float]: dictionary with scalar values for evaluation
            Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
        """
        return self.compute(*args, **kwargs)

    @abstractmethod
    def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]],
                ) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute metric

        Args:
            results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
                per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
                `dtMatches`: matched detections [T, D], where T = number of thresholds,
                    D = number of detections
                `gtMatches`: matched ground truth boxes [T, G], where T = number of thresholds,
                    G = number of ground truth
                `dtScores`: prediction scores [D] detection scores
                `gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth
                    should be ignored
                `dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored

        Returns:
            Dict[str, float]: dictionary with scalar values for evaluation
            Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
        """
        raise NotImplementedError


class DetectionMetric(AbstractMetric):
    """
    Metric which is evaluated at a set of IoU thresholds.
    """

    @abstractmethod
    def get_iou_thresholds(self) -> Sequence[float]:
        """
        Return the IoU thresholds needed for this metric

        Returns:
            Sequence[float]: IoU thresholds; [M], M is the number of thresholds
        """
        raise NotImplementedError

    def check_number_of_iou(self, *args) -> None:
        """
        Check if shape of input in first dimension is consistent with expected IoU values
        (assumes IoU dimension is the first dimension)

        Args:
            *args: array like inputs which provide a `shape` attribute

        Raises:
            AssertionError: if the first dimension of any input does not match
                the number of IoU thresholds returned by :meth:`get_iou_thresholds`
        """
        num_ious = len(self.get_iou_thresholds())
        for idx, arg in enumerate(args):
            # raise explicitly instead of using an `assert` statement so the
            # check is not silently removed when Python runs with `-O`;
            # AssertionError is kept for compatibility with existing callers
            if arg.shape[0] != num_ious:
                raise AssertionError(
                    "Expected first dimension of argument {} to match the number "
                    "of IoU thresholds ({}), found shape {}.".format(idx, num_ious, arg.shape))