Commit 3e94607a authored by mibaumgartner

evaluator

parent a28153dd
from nndet.evaluator.abstract import AbstractMetric, AbstractEvaluator, DetectionMetric
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import abstractmethod, ABC
import numpy as np
from typing import Dict, List, Sequence, Tuple
__all__ = ["AbstractEvaluator", "AbstractMetric", "DetectionMetric"]
class AbstractEvaluator(ABC):
@abstractmethod
def run_online_evaluation(self, *args, **kwargs):
"""
Compute necessary values per batch for later evaluation
"""
raise NotImplementedError
@abstractmethod
def finish_online_evaluation(self, *args, **kwargs):
"""
Accumulate results from batches and compute metrics
"""
raise NotImplementedError
@abstractmethod
def reset(self):
"""
Reset internal state of evaluator
"""
raise NotImplementedError
class AbstractMetric(ABC):
def __call__(self, *args, **kwargs) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
"""
Compute metric. See :func:`compute` for more information.
Args:
*args: positional arguments passed to :func:`compute`
**kwargs: keyword arguments passed to :func:`compute`
Returns:
Dict[str, float]: dictionary with scalar values for evaluation
Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
"""
return self.compute(*args, **kwargs)
@abstractmethod
def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]
) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
"""
Compute metric
Args:
results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
`dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number of thresholds,
G = number of ground truth
`dtScores`: prediction scores [D], D = number of detections
`gtIgnore`: ground truth boxes which should be ignored [G], indicate whether ground truth
should be ignored
`dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored
Returns:
Dict[str, float]: dictionary with scalar values for evaluation
Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
"""
raise NotImplementedError
class DetectionMetric(AbstractMetric):
@abstractmethod
def get_iou_thresholds(self) -> Sequence[float]:
"""
Return IoU thresholds needed for this metric as a numpy array
Returns:
Sequence[float]: IoU thresholds; [M], M is the number of thresholds
"""
raise NotImplementedError
def check_number_of_iou(self, *args) -> None:
"""
Check if shape of input in first dimension is consistent with expected IoU values
(assumes IoU dimension is the first dimension)
Args:
args: array-like inputs; the first dimension is expected to index the IoU thresholds
"""
num_ious = len(self.get_iou_thresholds())
for arg in args:
assert arg.shape[0] == num_ious
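# Example (sketch): a minimal concrete metric following the interface above.
# The class name `_ExampleRecallMetric` and its simple recall computation are
# illustrative assumptions, not part of the nnDetection metric suite.
class _ExampleRecallMetric(DetectionMetric):
    def __init__(self, iou_thresholds: Sequence[float] = (0.1, 0.5)):
        self.iou_thresholds = iou_thresholds

    def get_iou_thresholds(self) -> Sequence[float]:
        return self.iou_thresholds

    def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]
                ) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        # flatten per image/per class results and count matched detections
        results = [r for image in results_list for r in image.values()]
        dt_matches = np.concatenate([r["dtMatches"] for r in results], axis=1)
        gt_ignore = np.concatenate([r["gtIgnore"] for r in results])
        self.check_number_of_iou(dt_matches)
        num_gt = max(int(np.count_nonzero(gt_ignore == 0)), 1)
        scores = {f"recall_IoU_{iou:.2f}": np.count_nonzero(dt_matches[idx]) / num_gt
                  for idx, iou in enumerate(self.iou_thresholds)}
        return scores, {}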
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
from typing import Dict, Sequence, Callable, Tuple, Union, Mapping
import numpy as np
from loguru import logger
from sklearn.metrics import accuracy_score, average_precision_score, confusion_matrix, \
f1_score, precision_score, recall_score, roc_auc_score
from nndet.evaluator import AbstractEvaluator
__all__ = ["CaseEvaluator"]
class _CaseEvaluator(AbstractEvaluator):
def __init__(self,
classes: Sequence[Union[str, int]],
score_metrics_scalar: Mapping[str, Callable] = None,
class_metrics_scalar: Mapping[str, Callable] = None,
score_metrics_curve: Mapping[str, Callable] = None,
class_metrics_curve: Mapping[str, Callable] = None,
target_class: int = None,
):
"""
Compute case level evaluation metrics
Predictions for individual instances are aggregated by using the
max of the predicted score for each class. The final class prediction
is computed by an argmax over these scores. The keys of the metric
mappings are later used as the keys of the result dict.
Args:
classes: classes present in the whole dataset
score_metrics_scalar: metrics which accept ground truth classes [N]
and prediction scores [N, C] for evaluation; N is the number
of cases and C is the number of classes. The output should
be a scalar.
class_metrics_scalar: metrics which accept ground truth classes [N]
and prediction classes [N] for evaluation; N is the number
of cases. The output should be a scalar.
score_metrics_curve: metrics which accept ground truth classes [N]
and prediction scores [N, C] for evaluation; N is the number
of cases and C is the number of classes. The output should be
an array like object.
class_metrics_curve: metrics which accept ground truth classes [N]
and prediction classes [N] for evaluation; N is the number
of cases. The output should be an array like object.
target_class: target class for case evaluation (internally
results are evaluated in a binary setting, target class vs. rest).
If None, falls back to foreground vs. background
"""
self.results_list = defaultdict(list)
self.score_metrics_scalar = score_metrics_scalar if score_metrics_scalar is not None else {}
self.class_metrics_scalar = class_metrics_scalar if class_metrics_scalar is not None else {}
self.score_metrics_curve = score_metrics_curve if score_metrics_curve is not None else {}
self.class_metrics_curve = class_metrics_curve if class_metrics_curve is not None else {}
self.target_class = target_class
self.classes = classes
self.num_classes = len(classes)
def reset(self):
"""
Reset internal state for new epoch
"""
self.results_list = defaultdict(list)
def run_online_evaluation(self,
pred_classes: Sequence[np.ndarray],
pred_scores: Sequence[np.ndarray],
gt_classes: Sequence[np.ndarray],
) -> Dict:
"""
Run evaluation on each case (accepts a batch of case results
at once).
Args:
pred_classes (Sequence[np.ndarray]): predicted classes from a batch
of cases; List[[D]], D number of predictions
pred_scores (Sequence[np.ndarray]): predicted score for each
bounding box; List[[D]], D number of predictions
gt_classes (Sequence[np.ndarray]): ground truth classes for each
instance in a case; List[[G]], G number of ground truth
Returns:
Dict: empty dict
Notes:
This caches the max predicted probability per class per element
and the unique classes present per element.
"""
case_classes = [np.unique(gtc) for gtc in gt_classes]
case_scores = []
for case_instance_scores, case_instance_classes in zip(pred_scores, pred_classes):
_scores = np.zeros(self.num_classes)
for instance_score, instance_class in zip(case_instance_scores, case_instance_classes):
if _scores[int(instance_class)] < instance_score:
_scores[int(instance_class)] = instance_score
case_scores.append(_scores)
self.results_list["case_classes"].extend(case_classes)
self.results_list["case_scores"].extend(case_scores)
return {}
def finish_online_evaluation(self) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
"""
Compute final scores and curves of metrics
Returns:
Dict: results of scalar metrics
Dict: results of curve metrics
"""
# aggregate cases
gt_classes = self.aggregate_classes()
pred_scores, pred_classes = self.aggregate_predictions()
# compute metrics
curve_results = {}
for key, metric in self.score_metrics_curve.items():
curve_results[key] = metric(gt_classes, pred_scores)
for key, metric in self.class_metrics_curve.items():
curve_results[key] = metric(gt_classes, pred_classes)
scalar_results = {}
for key, metric in self.score_metrics_scalar.items():
try:
scalar_results[key] = metric(gt_classes, pred_scores)
except (ValueError, RuntimeError) as e:
logger.warning(f"Metric {key} exited with error {e}; writing nan to result")
scalar_results[key] = np.nan
for key, metric in self.class_metrics_scalar.items():
try:
scalar_results[key] = metric(gt_classes, pred_classes)
except (ValueError, RuntimeError) as e:
logger.warning(f"Metric {key} exited with error {e}; writing nan to result")
scalar_results[key] = np.nan
return scalar_results, curve_results
def aggregate_classes(self) -> np.ndarray:
"""
Aggregate classes of each instance in a case to one case class
Returns:
np.ndarray: class per case [N], where N is the number of cases
"""
if self.target_class is not None:
gt_classes = np.asarray(
[int(self.target_class in cc) for cc in self.results_list["case_classes"]])
else:
gt_classes = np.asarray(
[1 if len(cc) > 0 else 0 for cc in self.results_list["case_classes"]])
return gt_classes
def aggregate_predictions(self) -> Tuple[np.ndarray, np.ndarray]:
"""
Aggregate prediction scores per class to case scores of the target class
Returns:
np.ndarray: predicted scores
np.ndarray: predicted classes
"""
_pred_scores = np.stack(self.results_list["case_scores"], axis=0) # N, num_classes
if self.target_class is not None:
pred_scores = _pred_scores[:, self.target_class] # N
# pred_classes = (np.argmax(_pred_scores, axis=1) == self.target_class).astype(np.int32) # N
# This is not always the correct choice, depending on the final
# nonlinearity of the network (sigmoid vs. softmax)
pred_classes = (pred_scores > 0.5).astype(np.int32) # N
else:
pred_scores = _pred_scores.max(axis=1) # N
pred_classes = (pred_scores > 0.5).astype(np.int32) # N
return pred_scores, pred_classes
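# Example (sketch): how instance level predictions collapse to case level
# results. The toy values and the helper name `_example_case_aggregation` are
# illustrative assumptions. With two classes and target_class=1, the case
# score is the maximum score of any class-1 instance and the case label is 1
# if any class-1 instance exists in the ground truth.
def _example_case_aggregation() -> Tuple[np.ndarray, np.ndarray]:
    evaluator = _CaseEvaluator(classes=[0, 1], target_class=1)
    evaluator.run_online_evaluation(
        pred_classes=[np.array([0, 1, 1]), np.array([0])],
        pred_scores=[np.array([0.2, 0.7, 0.4]), np.array([0.9])],
        gt_classes=[np.array([1]), np.array([0])],
    )
    gt_classes = evaluator.aggregate_classes()  # -> array([1, 0])
    pred_scores, pred_classes = evaluator.aggregate_predictions()
    # pred_scores  -> array([0.7, 0.0]) (max class-1 score per case)
    # pred_classes -> array([1, 0])     (score > 0.5)
    return pred_scores, gt_classes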
class CaseEvaluator(_CaseEvaluator):
@classmethod
def create(cls,
classes: Sequence[str],
target_class: int = None
):
"""
Evaluation on patient level
Args:
classes: classes present in dataset
target_class: if multiple classes are given, define
a target class to evaluate in a target class vs. rest setting.
Defaults to None.
Returns:
CaseEvaluator: evaluator
"""
# if len(classes) > 2 and target_class is None:
# f1_fn = partial(f1_score, average="macro")
# prec_fn = partial(precision_score, average="macro")
# rec_fn = partial(recall_score, average="macro")
# else:
f1_fn = f1_score
prec_fn = precision_score
rec_fn = recall_score
score_metrics_scalar = {"auc_case": roc_auc_score, "ap_case": average_precision_score}
class_metrics_scalar = {"f1_case": f1_fn, "prec_case": prec_fn,
"rec_case": rec_fn, "acc_case": accuracy_score}
score_metrics_curve = {}
class_metrics_curve = {"cfm_case": confusion_matrix}
return cls(classes=classes,
score_metrics_scalar=score_metrics_scalar,
class_metrics_scalar=class_metrics_scalar,
score_metrics_curve=score_metrics_curve,
class_metrics_curve=class_metrics_curve,
target_class=target_class,
)
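# Example (sketch): typical case level evaluation loop. The toy batches below
# (one foreground class, two cases) are illustrative assumptions.
def _example_case_evaluator() -> Dict[str, float]:
    evaluator = CaseEvaluator.create(classes=["lesion"], target_class=None)
    batches = [
        ([np.array([0, 0])], [np.array([0.8, 0.3])], [np.array([0])]),
        ([np.array([0])], [np.array([0.1])], [np.array([], dtype=int)]),
    ]
    for pred_classes, pred_scores, gt_classes in batches:
        evaluator.run_online_evaluation(pred_classes, pred_scores, gt_classes)
    scalar_results, curve_results = evaluator.finish_online_evaluation()
    return scalar_results  # e.g. {"auc_case": ..., "f1_case": ..., "acc_case": ...}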
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pathlib import Path
from functools import partial
from typing import Optional, Sequence, Callable, Dict, List, Tuple
import numpy as np
from nndet.evaluator.abstract import AbstractEvaluator, DetectionMetric
from nndet.evaluator.detection.matching import matching_batch
from nndet.detection.boxes import box_iou_np
from nndet.evaluator.detection.coco import COCOMetric
from nndet.evaluator.detection.froc import FROCMetric
from nndet.evaluator.detection.hist import PredictionHistogram
__all__ = ["DetectionEvaluator"]
class DetectionEvaluator(AbstractEvaluator):
def __init__(self,
metrics: Sequence[DetectionMetric],
iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] = box_iou_np,
max_detections: int = 100,
):
"""
Class to evaluate detection metrics
Args:
metrics (Sequence[DetectionMetric]): detection metrics to evaluate
iou_fn (Callable[[np.ndarray, np.ndarray], np.ndarray]): computes the overlap for each pair of boxes
max_detections (int): maximum number of detections per image (reduces computation)
"""
self.iou_fn = iou_fn
self.max_detections = max_detections
self.metrics = metrics
self.results_list = [] # store results of each image
self.iou_thresholds = self.get_unique_iou_thresholds()
self.iou_mapping = self.get_indices_of_iou_for_each_metric()
def get_unique_iou_thresholds(self):
"""
Compute unique set of iou thresholds
"""
iou_thresholds = [_i for i in self.metrics for _i in i.get_iou_thresholds()]
iou_thresholds = list(set(iou_thresholds))
iou_thresholds.sort()
return iou_thresholds
def get_indices_of_iou_for_each_metric(self):
"""
Find indices of iou thresholds for each metric
"""
return [[self.iou_thresholds.index(th) for th in m.get_iou_thresholds()]
for m in self.metrics]
def run_online_evaluation(self,
pred_boxes: Sequence[np.ndarray],
pred_classes: Sequence[np.ndarray],
pred_scores: Sequence[np.ndarray],
gt_boxes: Sequence[np.ndarray],
gt_classes: Sequence[np.ndarray],
gt_ignore: Sequence[Sequence[bool]] = None) -> Dict:
"""
Preprocess batch results for final evaluation
Args:
pred_boxes (Sequence[np.ndarray]): predicted boxes from single batch; List[[D, dim * 2]], D number of
predictions
pred_classes (Sequence[np.ndarray]): predicted classes from a single batch; List[[D]], D number of
predictions
pred_scores (Sequence[np.ndarray]): predicted score for each bounding box; List[[D]], D number of
predictions
gt_boxes (Sequence[np.ndarray]): ground truth boxes; List[[G, dim * 2]], G number of ground truth
gt_classes (Sequence[np.ndarray]): ground truth classes; List[[G]], G number of ground truth
gt_ignore (Sequence[Sequence[bool]]): specifies which ground truth boxes are not counted as true
positives (detections which match these boxes are not counted as false positives either);
List[[G]], G number of ground truth
Returns:
dict: empty dict; detection metrics can only be evaluated at the end
"""
if gt_ignore is None:
gt_ignore = [np.zeros(gt_boxes_img.shape[0]).reshape(-1) for gt_boxes_img in gt_boxes]
self.results_list.extend(matching_batch(
self.iou_fn, self.iou_thresholds, pred_boxes=pred_boxes, pred_classes=pred_classes,
pred_scores=pred_scores, gt_boxes=gt_boxes, gt_classes=gt_classes, gt_ignore=gt_ignore,
max_detections=self.max_detections))
return {}
def finish_online_evaluation(self) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
"""
Accumulate results of individual batches and compute final metrics
Returns:
Dict[str, float]: dictionary with scalar values for evaluation
Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
"""
metric_scores = {}
metric_curves = {}
for metric_idx, metric in enumerate(self.metrics):
_filter = partial(self.iou_filter, iou_idx=self.iou_mapping[metric_idx])
iou_filtered_results = list(map(_filter, self.results_list))
score, curve = metric(iou_filtered_results)
if score is not None:
metric_scores.update(score)
if curve is not None:
metric_curves.update(curve)
return metric_scores, metric_curves
@staticmethod
def iou_filter(image_dict: Dict[int, Dict[str, np.ndarray]], iou_idx: List[int],
filter_keys: Sequence[str] = ('dtMatches', 'gtMatches', 'dtIgnore')):
"""
This function can be used to filter specific IoU values from the results
to make sure that the correct IoUs are passed to the metric
Args:
image_dict (dict): dictionary containing :param:`filter_keys` which hold the IoU values in the first dimension
iou_idx (List[int]): indices of IoU values to filter from the keys
filter_keys (tuple, optional): keys to filter, by default ('dtMatches', 'gtMatches', 'dtIgnore')
Returns:
dict: filtered dictionary
"""
iou_idx = list(iou_idx)
filtered = {}
for cls_key, cls_item in image_dict.items():
filtered[cls_key] = {key: item[iou_idx] if key in filter_keys else item
for key, item in cls_item.items()}
return filtered
def reset(self):
"""
Reset internal state of evaluator
"""
self.results_list = []
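# Example (sketch): how results are narrowed down to the IoU thresholds a
# single metric requested. The toy image_dict mirrors the per image output of
# :func:`matching_batch` for one class; all values are illustrative
# assumptions.
def _example_iou_filter() -> Dict[int, Dict[str, np.ndarray]]:
    # rows correspond to the evaluator wide thresholds, e.g. (0.1, 0.25, 0.5)
    image_dict = {
        0: {
            "dtMatches": np.array([[1, 1], [1, 0], [0, 0]]),
            "dtIgnore": np.zeros((3, 2)),
            "gtMatches": np.array([[1], [1], [0]]),
            "dtScores": np.array([0.9, 0.6]),
            "gtIgnore": np.zeros(1),
        }
    }
    # a metric which only needs IoU 0.1 and 0.5 receives rows 0 and 2
    return DetectionEvaluator.iou_filter(image_dict, iou_idx=[0, 2])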
class BoxEvaluator(DetectionEvaluator):
@classmethod
def create(cls,
classes: Sequence[str],
fast: bool = True,
verbose: bool = False,
save_dir: Optional[Path] = None,
):
"""
Create a box evaluator object
Args:
classes: classes present in the dataset
fast: reduces the evaluation suite to save time:
only evaluates IoUs in the range 0.1-0.5 and
does not calculate per class metrics
verbose: Additional logging output
save_dir: Path to save information
Returns:
BoxEvaluator: evaluator to efficiently compute metrics
"""
iou_fn = box_iou_np
iou_range = (0.1, 0.5, 0.05)
iou_thresholds = (0.1, 0.5) if fast else np.arange(0.1, 1.0, 0.1)
per_class = not fast
metrics = []
metrics.append(
FROCMetric(classes,
iou_thresholds=iou_thresholds,
fpi_thresholds=(1/8, 1/4, 1/2, 1, 2, 4, 8),
per_class=per_class,
verbose=verbose,
save_dir=None if fast else save_dir
)
)
metrics.append(
COCOMetric(classes,
iou_list=iou_thresholds,
iou_range=iou_range,
max_detection=(100, ),
per_class=per_class,
verbose=verbose,
)
)
if not fast:
metrics.append(
PredictionHistogram(classes=classes,
save_dir=save_dir,
iou_thresholds=(0.1, 0.5),
)
)
return cls(metrics=tuple(metrics), iou_fn=iou_fn)
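# Example (sketch): end to end detection evaluation for one toy batch with a
# single class. The 2D boxes and scores are illustrative assumptions.
def _example_box_evaluator() -> Dict[str, float]:
    evaluator = BoxEvaluator.create(classes=["lesion"], fast=True)
    evaluator.run_online_evaluation(
        pred_boxes=[np.array([[10, 10, 20, 20], [40, 40, 50, 50]])],
        pred_classes=[np.array([0, 0])],
        pred_scores=[np.array([0.9, 0.4])],
        gt_boxes=[np.array([[11, 11, 21, 21]])],
        gt_classes=[np.array([0])],
    )
    scores, curves = evaluator.finish_online_evaluation()
    return scores  # COCO and FROC entries keyed by IoU and MaxDet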
from nndet.evaluator.detection.froc import FROCMetric
from nndet.evaluator.detection.coco import COCOMetric
from nndet.evaluator.detection.hist import PredictionHistogram
import time
import numpy as np
from loguru import logger
from typing import Sequence, List, Dict, Union, Tuple
from nndet.evaluator import DetectionMetric
class COCOMetric(DetectionMetric):
def __init__(self,
classes: Sequence[str],
iou_list: Sequence[float] = (0.1, 0.5, 0.75),
iou_range: Sequence[float] = (0.1, 0.5, 0.05),
max_detection: Sequence[int] = (1, 5, 100),
per_class: bool = True,
verbose: bool = True):
"""
Class to compute COCO metrics
Metrics computed:
mAP over the IoU range specified by :param:`iou_range` at last value of :param:`max_detection`
AP values at IoU thresholds specified by :param:`iou_list` at last value of :param:`max_detection`
AR over max detections thresholds defined by :param:`max_detection` (over iou range)
Args:
classes (Sequence[str]): name of each class (index needs to correspond to predicted class indices!)
iou_list (Sequence[float]): specific thresholds where ap is evaluated and saved
iou_range (Sequence[float]): (start, stop, step) for mAP iou thresholds
max_detection (Sequence[int]): maximum number of detections per image
per_class (bool): additionally compute metrics for each class separately
verbose (bool): log time needed for evaluation
"""
self.verbose = verbose
self.classes = classes
self.per_class = per_class
iou_list = np.array(iou_list)
_iou_range = np.linspace(iou_range[0], iou_range[1],
int(np.round((iou_range[1] - iou_range[0]) / iou_range[2])) + 1, endpoint=True)
self.iou_thresholds = np.union1d(iou_list, _iou_range)
self.iou_range = iou_range
# get indices of iou values of ious range and ious list for later evaluation
self.iou_list_idx = np.nonzero(iou_list[:, np.newaxis] == self.iou_thresholds[np.newaxis])[1]
self.iou_range_idx = np.nonzero(_iou_range[:, np.newaxis] == self.iou_thresholds[np.newaxis])[1]
assert (self.iou_thresholds[self.iou_list_idx] == iou_list).all()
assert (self.iou_thresholds[self.iou_range_idx] == _iou_range).all()
self.recall_thresholds = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
self.max_detections = max_detection
def get_iou_thresholds(self) -> Sequence[float]:
"""
Return IoU thresholds needed for this metric as a numpy array
Returns:
Sequence[float]: IoU thresholds [M], M is the number of thresholds
"""
return self.iou_thresholds
def compute(self,
results_list: List[Dict[int, Dict[str, np.ndarray]]],
) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
"""
Compute COCO metrics
Args:
results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
`dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number of thresholds, G = number of
ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth
should be ignored
`dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored
Returns:
Dict[str, float]: dictionary with coco metrics
Dict[str, np.ndarray]: None
"""
if self.verbose:
logger.info('Start COCO metric computation...')
tic = time.time()
dataset_statistics = self.compute_statistics(results_list=results_list)
if self.verbose:
toc = time.time()
logger.info(f'Statistics for COCO metrics finished (t={(toc - tic):0.2f}s).')
results = {}
results.update(self.compute_ap(dataset_statistics))
results.update(self.compute_ar(dataset_statistics))
if self.verbose:
toc = time.time()
logger.info(f'COCO metrics computed in t={(toc - tic):0.2f}s.')
return results, None
def compute_ap(self, dataset_statistics: dict) -> dict:
"""
Compute AP metrics
Args:
dataset_statistics (dict): computed statistics over the dataset
(as returned by :func:`compute_statistics`)
`counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
detection thresholds
`recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
`precision`: Precision values at specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
`scores`: Scores corresponding to specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
Returns:
dict: computed AP metrics (key: metric name, value: scalar)
"""
results = {}
if self.iou_range: # mAP
key = (f"mAP_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_"
f"MaxDet_{self.max_detections[-1]}")
results[key] = self.select_ap(dataset_statistics, iou_idx=self.iou_range_idx, max_det_idx=-1)
if self.per_class:
for cls_idx, cls_str in enumerate(self.classes): # per class results
key = (f"{cls_str}_"
f"mAP_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_"
f"MaxDet_{self.max_detections[-1]}")
results[key] = self.select_ap(dataset_statistics, iou_idx=self.iou_range_idx,
cls_idx=cls_idx, max_det_idx=-1)
for idx in self.iou_list_idx: # AP@IoU
key = f"AP_IoU_{self.iou_thresholds[idx]:.2f}_MaxDet_{self.max_detections[-1]}"
results[key] = self.select_ap(dataset_statistics, iou_idx=[idx], max_det_idx=-1)
if self.per_class:
for cls_idx, cls_str in enumerate(self.classes): # per class results
key = (f"{cls_str}_"
f"AP_IoU_{self.iou_thresholds[idx]:.2f}_"
f"MaxDet_{self.max_detections[-1]}")
results[key] = self.select_ap(dataset_statistics,
iou_idx=[idx], cls_idx=cls_idx, max_det_idx=-1)
return results
def compute_ar(self, dataset_statistics: dict) -> dict:
"""
Compute AR metrics
Args:
dataset_statistics (dict): computed statistics over the dataset
(as returned by :func:`compute_statistics`)
`counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
detection thresholds
`recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
`precision`: Precision values at specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
`scores`: Scores corresponding to specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
Returns:
dict: computed AR metrics (key: metric name, value: scalar)
"""
results = {}
for max_det_idx, max_det in enumerate(self.max_detections): # mAR
key = f"mAR_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_MaxDet_{max_det}"
results[key] = self.select_ar(dataset_statistics, max_det_idx=max_det_idx)
if self.per_class:
for cls_idx, cls_str in enumerate(self.classes): # per class results
key = (f"{cls_str}_"
f"mAR_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_"
f"MaxDet_{max_det}")
results[key] = self.select_ar(dataset_statistics,
cls_idx=cls_idx, max_det_idx=max_det_idx)
for idx in self.iou_list_idx: # AR@IoU
key = f"AR_IoU_{self.iou_thresholds[idx]:.2f}_MaxDet_{self.max_detections[-1]}"
results[key] = self.select_ar(dataset_statistics, iou_idx=idx, max_det_idx=-1)
if self.per_class:
for cls_idx, cls_str in enumerate(self.classes): # per class results
key = (f"{cls_str}_"
f"AR_IoU_{self.iou_thresholds[idx]:.2f}_"
f"MaxDet_{self.max_detections[-1]}")
results[key] = self.select_ar(dataset_statistics, iou_idx=idx,
cls_idx=cls_idx, max_det_idx=-1)
return results
@staticmethod
def select_ap(dataset_statistics: dict, iou_idx: Union[int, List[int]] = None,
cls_idx: Union[int, Sequence[int]] = None, max_det_idx: int = -1) -> np.ndarray:
"""
Compute average precision
Args:
dataset_statistics (dict): computed statistics over dataset
`counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
detection thresholds
`recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
`precision`: Precision values at specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
`scores`: Scores corresponding to specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
iou_idx: index of IoU values to select for evaluation (if None, all values are used)
cls_idx: class indices to select, if None all classes will be selected
max_det_idx (int): index to select max detection threshold from data
Returns:
np.ndarray: AP value
"""
prec = dataset_statistics["precision"]
if iou_idx is not None:
prec = prec[iou_idx]
if cls_idx is not None:
prec = prec[..., cls_idx, :]
prec = prec[..., max_det_idx]
return np.mean(prec)
@staticmethod
def select_ar(dataset_statistics: dict, iou_idx: Union[int, Sequence[int]] = None,
cls_idx: Union[int, Sequence[int]] = None,
max_det_idx: int = -1) -> np.ndarray:
"""
Compute average recall
Args:
dataset_statistics (dict): computed statistics over dataset
`counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
detection thresholds
`recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
`precision`: Precision values at specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
`scores`: Scores corresponding to specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
iou_idx: index of IoU values to select for evaluation (if None, all values are used)
cls_idx: class indices to select, if None all classes will be selected
max_det_idx (int): index to select max detection threshold from data
Returns:
np.ndarray: recall value
"""
rec = dataset_statistics["recall"]
if iou_idx is not None:
rec = rec[iou_idx]
if cls_idx is not None:
rec = rec[..., cls_idx, :]
rec = rec[..., max_det_idx]
if len(rec[rec > -1]) == 0:
rec = -1
else:
rec = np.mean(rec[rec > -1])
return rec
def compute_statistics(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]
) -> Dict[str, Union[np.ndarray, List]]:
"""
Compute statistics needed for COCO metrics (mAP, AP of individual classes, mAP@IoU_Thresholds, AR)
Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py
Args:
results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
`dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number of thresholds, G = number of
ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth should be
ignored
`dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored
Returns:
dict: computed statistics over dataset
`counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
detection thresholds
`recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
`precision`: Precision values at specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
`scores`: Scores corresponding to specified recall thresholds
[num_iou_th, num_recall_th, num_classes, num_max_detections]
"""
num_iou_th = len(self.iou_thresholds)
num_recall_th = len(self.recall_thresholds)
num_classes = len(self.classes)
num_max_detections = len(self.max_detections)
# -1 for the precision of absent categories
precision = -np.ones((num_iou_th, num_recall_th, num_classes, num_max_detections))
recall = -np.ones((num_iou_th, num_classes, num_max_detections))
scores = -np.ones((num_iou_th, num_recall_th, num_classes, num_max_detections))
for cls_idx, cls_i in enumerate(self.classes): # for each class
for maxDet_idx, maxDet in enumerate(self.max_detections): # for each maximum number of detections
results = [r[cls_idx] for r in results_list if cls_idx in r]
if len(results) == 0:
logger.warning(f"WARNING, no results found for coco metric for class {cls_i}")
continue
dt_scores = np.concatenate([r['dtScores'][0:maxDet] for r in results])
# different sorting methods generate slightly different results;
# mergesort is used to be consistent with the Matlab implementation.
inds = np.argsort(-dt_scores, kind='mergesort')
dt_scores_sorted = dt_scores[inds]
# r['dtMatches'] [T, R], where R = sum(all detections)
dt_matches = np.concatenate([r['dtMatches'][:, 0:maxDet] for r in results], axis=1)[:, inds]
dt_ignores = np.concatenate([r['dtIgnore'][:, 0:maxDet] for r in results], axis=1)[:, inds]
self.check_number_of_iou(dt_matches, dt_ignores)
gt_ignore = np.concatenate([r['gtIgnore'] for r in results])
num_gt = np.count_nonzero(gt_ignore == 0) # number of ground truth boxes (non ignored)
if num_gt == 0:
logger.warning(f"WARNING, no gt found for coco metric for class {cls_i}")
continue
# ignore cases need to be handled differently for tp and fp
tps = np.logical_and(dt_matches, np.logical_not(dt_ignores))
fps = np.logical_and(np.logical_not(dt_matches), np.logical_not(dt_ignores))
tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float32)
fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float32)
for th_ind, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): # for each threshold th_ind
tp, fp = np.array(tp), np.array(fp)
r, p, s = compute_stats_single_threshold(tp, fp, dt_scores_sorted, self.recall_thresholds, num_gt)
recall[th_ind, cls_idx, maxDet_idx] = r
precision[th_ind, :, cls_idx, maxDet_idx] = p
# corresponding score thresholds for recall steps
scores[th_ind, :, cls_idx, maxDet_idx] = s
return {
'counts': [num_iou_th, num_recall_th, num_classes, num_max_detections], # [4]
'recall': recall, # [num_iou_th, num_classes, num_max_detections]
'precision': precision, # [num_iou_th, num_recall_th, num_classes, num_max_detections]
'scores': scores, # [num_iou_th, num_recall_th, num_classes, num_max_detections]
}
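# Example (sketch): how the IoU grid of COCOMetric is assembled. With the
# default iou_list=(0.1, 0.5, 0.75) and iou_range=(0.1, 0.5, 0.05) the
# thresholds are the union of both, and iou_list_idx/iou_range_idx map back
# into that union. The helper name is an illustrative assumption.
def _example_coco_iou_grid() -> np.ndarray:
    metric = COCOMetric(classes=["lesion"])
    # union of (0.1, 0.5, 0.75) and 0.1:0.5:0.05 -> 10 unique thresholds
    assert len(metric.iou_thresholds) == 10
    assert np.allclose(metric.iou_thresholds[metric.iou_list_idx], (0.1, 0.5, 0.75))
    return metric.iou_thresholds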
def compute_stats_single_threshold(tp: np.ndarray, fp: np.ndarray, dt_scores_sorted: np.ndarray,
recall_thresholds: Sequence[float], num_gt: int) -> Tuple[
float, np.ndarray, np.ndarray]:
"""
Compute recall value, precision curve and scores thresholds
Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py
Args:
tp (np.ndarray): cumsum over true positives [R], R is the number of detections
fp (np.ndarray): cumsum over false positives [R], R is the number of detections
dt_scores_sorted (np.ndarray): sorted (descending) scores [R], R is the number of detections
recall_thresholds (Sequence[float]): recall thresholds which should be evaluated
num_gt (int): number of ground truth bounding boxes (excluding boxes which are ignored)
Returns:
float: overall recall for given IoU value
np.ndarray: precision values at defined recall values
[RTH], where RTH is the number of recall thresholds
np.ndarray: prediction scores corresponding to recall values
[RTH], where RTH is the number of recall thresholds
"""
num_recall_th = len(recall_thresholds)
rc = tp / num_gt
# np.spacing(1) is the machine epsilon of float (avoids division by zero)
pr = tp / (fp + tp + np.spacing(1))
if len(tp):
recall = rc[-1]
else:
# no prediction
recall = 0
# array where precision values nearest to given recall th are saved
precision = np.zeros((num_recall_th,))
# save scores for corresponding recall value in here
th_scores = np.zeros((num_recall_th,))
# accessing elements of numpy arrays is slow without cython optimization;
# using python lists gives a significant speed improvement
pr = pr.tolist()
precision = precision.tolist()
# smooth precision curve (create box shape)
for i in range(len(tp) - 1, 0, -1):
if pr[i] > pr[i-1]:
pr[i-1] = pr[i]
# get indices to nearest given recall threshold (nn interpolation!)
inds = np.searchsorted(rc, recall_thresholds, side='left')
try:
for save_idx, array_index in enumerate(inds):
precision[save_idx] = pr[array_index]
th_scores[save_idx] = dt_scores_sorted[array_index]
except IndexError:
# searchsorted may return len(pr) if a recall threshold is never reached
pass
return recall, np.array(precision), np.array(th_scores)
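# Example (sketch): precision/recall bookkeeping for a single IoU threshold.
# Three detections sorted by score, the first and third are true positives and
# two ground truth boxes exist; all numbers are illustrative assumptions.
def _example_single_threshold_stats() -> Tuple[float, np.ndarray, np.ndarray]:
    tp = np.array([1, 1, 2], dtype=float)  # cumulative true positives
    fp = np.array([0, 1, 1], dtype=float)  # cumulative false positives
    dt_scores_sorted = np.array([0.9, 0.6, 0.4])
    recall, precision, th_scores = compute_stats_single_threshold(
        tp, fp, dt_scores_sorted, recall_thresholds=(0.5, 1.0), num_gt=2)
    # recall == 1.0, precision ~ [1.0, 0.667], th_scores == [0.9, 0.4]
    return recall, precision, th_scores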
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import time
import numpy as np
from loguru import logger
from typing import Sequence, List, Dict, Optional, Union, Tuple
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from nndet.evaluator import DetectionMetric
from sklearn.metrics import roc_curve
from collections import defaultdict
class FROCMetric(DetectionMetric):
def __init__(self,
classes: Sequence[str],
iou_thresholds: Sequence[float] = (0.1, 0.5),
fpi_thresholds: Sequence[float] = (1/8, 1/4, 1/2, 1, 2, 4, 8),
per_class: bool = False, verbose: bool = True,
save_dir: Optional[Union[str, Path]] = None,
):
"""
Class to compute FROC
Args:
classes: name of each class
(index needs to correspond to predicted class indices!)
iou_thresholds: IoU thresholds for which FROC
is evaluated
fpi_thresholds: false positives per image
thresholds (curve is interpolated at these values, score is
the mean of the computed sens values at these positions)
per_class: additional FROC curves are computed per class
verbose: log time needed for evaluation
save_dir: directory where the FROC plots are saved (no plots are saved if None)
"""
self.classes = classes
self.iou_thresholds = iou_thresholds
self.fpi_thresholds = fpi_thresholds
self.per_class = per_class
self.verbose = verbose
if save_dir is None:
self.save_dir = save_dir
else:
self.save_dir = Path(save_dir)
def get_iou_thresholds(self) -> Sequence[float]:
"""
Return IoU thresholds needed for this metric as a numpy array
Returns:
Sequence[float]: IoU thresholds [M], M is the number of thresholds
"""
return self.iou_thresholds
def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[
Dict[str, float], Dict[str, np.ndarray]]:
"""
Compute FROC
Args:
results_list: list with results per image (in list)
per category (dict). Inner Dict contains multiple results
obtained by :func:`box_matching_batch`.
`dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where
T = number of thresholds, G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored
[G] indicate whether ground truth should be ignored
`dtIgnore`: detections which should be ignored [T, D],
indicate which detections should be ignored
Returns:
Dict[str, float]: FROC score per IoU (key: FROC_score_IoU_{iou:.2f})
Dict[str, np.ndarray]: FROC curve computed at the specified fpi
thresholds per IoU; [R] R is the number of fpi thresholds
(key: FROC_curve_IoU_{iou:.2f})
"""
if self.verbose:
logger.info('Start FROC metric computation...')
tic = time.time()
scores = {}
curves = {}
_score, _curve = self.compute_froc_mul_iou(results_list)
scores.update(_score)
curves.update(_curve)
if self.verbose:
toc = time.time()
logger.info(f'FROC finished (t={(toc - tic):0.2f}s).')
if self.per_class:
_score, _curve = self.compute_froc_mul_iou_per_class(results_list)
scores.update(_score)
curves.update(_curve)
if self.verbose:
toc = time.time()
logger.info(f'FROC per class finished (t={(toc - tic):0.2f}s).')
if self.save_dir is not None:
self.plot_froc_curves(curves)
return scores, curves
def compute_froc_mul_iou(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[
Dict[str, float], Dict[str, np.ndarray]]:
"""
Compute FROC curve for multiple IoU values
Args:
results_list: list with results per image (in list)
per category (dict). Inner Dict contains multiple results
obtained by :func:`box_matching_batch`.
`dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where
T = number of thresholds, G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored
[G] indicate whether ground truth should be ignored
`dtIgnore`: detections which should be ignored [T, D],
indicate which detections should be ignored
Returns:
Dict[str, float]: FROC score per IoU
Dict[str,np.ndarray]: FROC curve computed at specified fps
thresholds per IoU; [R] R is the number of fps thresholds
"""
num_images = len(results_list)
results = [_r for r in results_list for _r in r.values()]
if len(results) == 0:
logger.warning(f"WARNING, no results found for froc computation")
return ({"froc_score": 0},
{"froc_curve": np.zeros(len(self.fpi_thresholds))})
# r['dtMatches'] [T, R], where R = sum(all detections)
dt_matches = np.concatenate([r['dtMatches'] for r in results], axis=1)
dt_ignores = np.concatenate([r['dtIgnore'] for r in results], axis=1)
dt_scores = np.concatenate([r['dtScores'] for r in results])
gt_ignore = np.concatenate([r['gtIgnore'] for r in results])
self.check_number_of_iou(dt_matches, dt_ignores)
num_gt = np.count_nonzero(gt_ignore == 0) # number of ground truth boxes (non ignored)
if num_gt == 0:
logger.error("No ground truth found! Returning 0 in FROC.")
return ({"froc_score": 0},
{"froc_curve": np.zeros(len(self.fpi_thresholds))})
# keep shape in case of 1 threshold
old_shape = dt_matches.shape
dt_matches = dt_matches[np.logical_not(dt_ignores)].reshape(old_shape)
curves = {}
for iou_idx, iou_val in enumerate(self.iou_thresholds):
# filter out scores of ignored detections
_scores = dt_scores[np.logical_not(dt_ignores[iou_idx])]
assert len(_scores) == len(dt_matches[iou_idx])
_fps, _sens, _th = (self.compute_froc_curve_one_iou(
dt_matches[iou_idx], _scores, num_images, num_gt))
# linearly interpolate the curve at the defined fpi thresholds
curves[iou_val] = np.interp(self.fpi_thresholds, _fps, _sens)
scores = {f"FROC_score_IoU_{key:.2f}": np.mean(c) for key, c in curves.items()}
curves = {f"FROC_curve_IoU_{key:.2f}": c for key, c in curves.items()}
curves["FROC_fpi_thresholds"] = self.fpi_thresholds
return scores, curves
@staticmethod
def compute_froc_curve_one_iou(dt_matches: np.ndarray, dt_scores: np.ndarray,
num_images: int, num_gt: int):
"""
Compute FROC curve for a single IoU value
Args:
dt_matches (np.ndarray): binary array indicating which bounding
boxes have a large enough overlap with gt;
[R] where R is the number of predictions
dt_scores (np.ndarray): prediction score for each bounding box;
[R] where R is the number of predictions
num_images (int): number of images
num_gt (int): number of ground truth bounding boxes
Returns:
np.ndarray: false positives per image
np.ndarray: sensitivity
np.ndarray: thresholds
"""
num_detections = len(dt_matches)
num_matched = np.sum(dt_matches)
num_unmatched = num_detections - num_matched
if dt_matches.size == 0:
logger.warning("WARNING, no matches found.")
return np.zeros((2,)), np.zeros((2,)), np.zeros((2,))
else:
fpr, tpr, thresholds = roc_curve(dt_matches, dt_scores)
if num_unmatched == 0:
logger.warning("WARNING, no false positives found")
fps = np.zeros(len(fpr))
else:
fps = (fpr * num_unmatched) / num_images
sens = (tpr * num_matched) / num_gt
return fps, sens, thresholds
def compute_froc_mul_iou_per_class(
self, results_list: List[Dict[int, Dict[str, np.ndarray]]],
) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
"""
Compute FROC curve for multiple classes
Args:
results_list: list with results per image (in list)
per category (dict). Inner Dict contains multiple results
obtained by :func:`box_matching_batch`.
`dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where
T = number of thresholds, G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored
[G] indicate whether ground truth should be ignored
`dtIgnore`: detections which should be ignored [T, D],
indicate which detections should be ignored
Returns:
Dict[str, float]: FROC score computed per class per IoU
Dict[str, np.ndarray]: FROC curve computed per class per IoU;
[R] R is the number of fps thresholds
"""
froc_scores_cls = {}
froc_curves_cls = {}
for cls_idx, cls_str in enumerate(self.classes):
# filter current class from list of results and put them into a dict with a single entry
results_by_cls = [{0: r[cls_idx]} for r in results_list if cls_idx in r]
if results_by_cls:
cls_scores, cls_curves = self.compute_froc_mul_iou(results_by_cls)
froc_scores_cls.update({f"{cls_str}_{key}": item for key, item in cls_scores.items()})
froc_curves_cls.update({f"{cls_str}_{key}": item for key, item in cls_curves.items()})
return froc_scores_cls, froc_curves_cls
def plot_froc_curves(self, curves: Dict[str, Sequence[float]]) -> None:
"""
Plot frocs
Args:
curves: dict with froc curves (as obtained by :method:`compute`)
FROC_curve_IoU_{iou:.2f} for the class independent FROC
{cls_name}_FROC_curve_IoU_{iou:.2f} for the class specific FROC
"""
# plot normal froc curves
selection = select_froc_curves(curves)
fig, ax = get_froc_ax(self.fpi_thresholds)
for _, froc, iou in zip(*selection):
ax.plot(self.fpi_thresholds, froc, 'o-', label=f"IoU:{iou:.2f}")
ax.set_title("FROC")
ax.legend(loc='lower right')
fig.savefig(self.save_dir / "FROC.png")
plt.close(fig)
# plot cls frocs
selection = select_froc_curves_cls(curves)
reordered = defaultdict(list)
for class_name, (names, frocs, ious) in selection.items():
for froc, iou in zip(frocs, ious):
reordered[iou].append((class_name, froc))
for iou, frocs in reordered.items():
fig, ax = get_froc_ax(self.fpi_thresholds)
for class_name, froc in frocs:
ax.plot(self.fpi_thresholds, froc, 'o-', label=f"{class_name}")
title = f"FROC_cls_IoU_{iou:.2f}"
ax.set_title(title)
ax.legend(loc='lower right')
fig.savefig(self.save_dir / f"{title.replace('.', '_')}.png")
plt.close(fig)
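# Example (sketch): sensitivity vs. false positives per image for one IoU
# value. Four detections over two images, three of them matched to the three
# ground truth boxes; all numbers are illustrative assumptions.
def _example_froc_one_iou() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    dt_matches = np.array([1, 1, 0, 1])
    dt_scores = np.array([0.9, 0.8, 0.6, 0.3])
    fps, sens, thresholds = FROCMetric.compute_froc_curve_one_iou(
        dt_matches, dt_scores, num_images=2, num_gt=3)
    # sens rises towards 3/3 matched ground truth and fps towards 1/2 false
    # positives per image as the score threshold decreases
    return fps, sens, thresholds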
def get_froc_ax(fpi_values: Optional[Sequence[float]] = None) -> Tuple[plt.Figure, plt.Axes]:
"""
Create preconfigured figure and axes object for froc curves
Args:
fpi_values: x values to use for froc
Returns:
plt.Figure: figure object
plt.Axes: configured axes object
"""
fig, ax = plt.subplots()
ax.set_xscale("log", base=2)
if fpi_values is not None:
ax.set_xlim(min(fpi_values), max(fpi_values))
ax.set_xticks(fpi_values)
ax.set_ylim(0, 1)
ax.set_xlabel('Avg number of false positives per scan')
ax.set_ylabel('Sensitivity')
ax.grid(True)
formatter = FuncFormatter(lambda y, _: '{:.3f}'.format(y))
ax.xaxis.set_major_formatter(formatter)
return fig, ax
def select_froc_curves(curves: Dict[str, np.ndarray], prefix: Optional[str] = None) -> \
Tuple[List[str], List[np.ndarray], List[float]]:
"""
Select FROC curves
Args:
curves: dict to select FROC curves from. Curves need to follow
the {prefix}FROC_curve_IoU_{iou:.2f} pattern
prefix: optional prefix (e.g. a class name) prepended to the keys
Returns:
Tuple[List[str], List[np.ndarray], List[float]]:
selected keys, the corresponding curves and their IoU values
"""
if prefix is None:
prefix = ""
froc_keys = [str(c) for c in curves.keys()
if str(c).startswith(f"{prefix}FROC_") and
not str(c).endswith("_thresholds")]
frocs = [curves[c] for c in froc_keys]
ious = [float(c.rsplit('_', 1)[1]) for c in froc_keys]
return froc_keys, frocs, ious
def select_froc_curves_cls(curves: Dict[str, np.ndarray]) -> \
Dict[str, Tuple[List[str], List[np.ndarray], List[float]]]:
"""
Select class specific froc curves
Args:
curves: dict to select FROC curves from. Class specific curves need to follow the
{cls_name}_FROC_curve_IoU_{iou:.2f} pattern
Returns:
Dict[str, Tuple[List[str], List[np.ndarray], List[float]]]:
dict defines the classes, tuple is output from
:method:`select_froc_curves`
"""
all_classes = [str(c).split('_', 1)[0] for c in curves.keys()
if not str(c).startswith("FROC_") and
not str(c).endswith("_thresholds")]
all_classes = list(set(all_classes))
output = {}
for cls_name in all_classes:
output[cls_name] = select_froc_curves(curves, prefix=f"{cls_name}_")
return output
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import time
import numpy as np
from pathlib import Path
from loguru import logger
from typing import Sequence, List, Dict, Any, Tuple
import matplotlib.pyplot as plt
from nndet.evaluator import DetectionMetric
class PredictionHistogram(DetectionMetric):
def __init__(self,
classes: Sequence[str], save_dir: Path,
iou_thresholds: Sequence[float] = (0.1, 0.5),
bins: int = 50):
"""
Class to compute prediction histograms. (Note: this class does not
provide any scalar metrics)
Args:
classes: name of each class (index needs to correspond to predicted class indices!)
save_dir: directory where histograms are saved to
iou_thresholds: IoU thresholds for which histograms are computed
bins: number of bins of histogram
"""
self.classes = classes
self.save_dir = save_dir
self.iou_thresholds = iou_thresholds
self.bins = bins
def get_iou_thresholds(self) -> Sequence[float]:
"""
Return IoU thresholds needed for this metric as a numpy array
Returns:
Sequence[float]: IoU thresholds [M], M is the number of thresholds
"""
return self.iou_thresholds
def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[
Dict[str, float], Dict[str, Dict[str, Any]]]:
"""
Plot class independent and per class histograms. For more info see
:method:`plot_hist`
Args:
results_list: matched results over the dataset (see :method:`plot_hist`)
"""
self.plot_hist(results_list=results_list)
for cls_idx, cls_str in enumerate(self.classes):
# filter current class from list of results and put them into a dict with a single entry
results_by_cls = [{0: r[cls_idx]} for r in results_list if cls_idx in r]
self.plot_hist(results_by_cls, title_prefix=f"cl_{cls_str}_")
return {}, {}
def plot_hist(self, results_list: List[Dict[int, Dict[str, np.ndarray]]],
title_prefix: str = "") -> Tuple[
Dict[str, float], Dict[str, Dict[str, Any]]]:
"""
Compute prediction histograms for multiple IoU values
Args:
results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
`dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number of thresholds,
G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth
should be ignored
`dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored
title_prefix: prefix for title of histogram plot
Returns:
Dict: empty
Dict: empty (the histograms are saved to :attr:`self.save_dir` instead of being returned)
Notes:
For each IoU value a histogram plot is created which shows:
`tp_hist`: histogram of true positive scores; false negatives appear at score=0 [:attr:`self.bins`]
`fp_hist`: histogram of false positive scores [:attr:`self.bins`]
`true_positives` (int): number of true positives according to matching
`false_positives` (int): number of false positives according to matching
`false_negatives` (int): number of false negatives according to matching
"""
num_images = len(results_list)
results = [_r for r in results_list for _r in r.values()]
if len(results) == 0:
logger.warning(f"WARNING, no results found for froc computation")
return {}, {}
# r['dtMatches'] [T, R], where R = sum(all detections)
dt_matches = np.concatenate([r['dtMatches'] for r in results], axis=1)
dt_ignores = np.concatenate([r['dtIgnore'] for r in results], axis=1)
dt_scores = np.concatenate([r['dtScores'] for r in results])
gt_ignore = np.concatenate([r['gtIgnore'] for r in results])
self.check_number_of_iou(dt_matches, dt_ignores)
num_gt = np.count_nonzero(gt_ignore == 0) # number of ground truth boxes (non ignored)
if num_gt == 0:
logger.error("No ground truth found! Returning nothing.")
return {}, {}
for iou_idx, iou_val in enumerate(self.iou_thresholds):
# filter out scores of ignored detections
_scores = dt_scores[np.logical_not(dt_ignores[iou_idx])]
assert len(_scores) == len(dt_matches[iou_idx])
_ = self.compute_histogram_one_iou(
dt_matches[iou_idx], _scores, num_images, num_gt, iou_val, title_prefix)
return {}, {}
def compute_histogram_one_iou(self, dt_matches: np.ndarray, dt_scores: np.ndarray,
num_images: int, num_gt: int, iou: float,
title_prefix: str):
"""
Plot prediction histogram
Args:
dt_matches (np.ndarray): binary array indicating which bounding
boxes have a large enough overlap with gt;
[R] where R is the number of predictions
dt_scores (np.ndarray): prediction score for each bounding box;
[R] where R is the number of predictions
num_images (int): number of images
num_gt (int): number of ground truth bounding boxes
iou: IoU values which is currently evaluated
title_prefix: prefix for title of histogram plot
"""
num_matched = np.sum(dt_matches)
false_negatives = num_gt - num_matched # false negatives
true_positives = np.sum(dt_matches)
false_positives = np.sum(dt_matches == 0)
_dt_matches = np.concatenate([dt_matches, [1] * int(false_negatives)])
_dt_scores = np.concatenate([dt_scores, [0] * int(false_negatives)])
plt.figure()
plt.yscale('log')
if 0 in dt_matches:
plt.hist(_dt_scores[_dt_matches == 0], bins=self.bins, range=(0., 1.),
alpha=0.3, color='g', label='false pos.')
if 1 in dt_matches:
plt.hist(_dt_scores[_dt_matches == 1], bins=self.bins, range=(0., 1.),
alpha=0.3, color='b', label='true pos. (false neg. @ score=0)')
plt.legend()
title = title_prefix + (f"tp:{true_positives} fp:{false_positives} "
f"fn:{false_negatives} pos:{true_positives+false_negatives}")
plt.title(title)
plt.xlabel('confidence score')
plt.ylabel('log n')
if self.save_dir is not None:
save_path = self.save_dir / (f"{title_prefix}pred_hist_IoU@{iou}".replace(".", "_") + ".png")
logger.info(f"Saving {save_path}")
plt.savefig(save_path)
plt.close()
return None
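# Example (sketch): plotting prediction histograms from matched results. The
# results_list format is the one produced by :func:`matching_batch`; the
# single image dict and the output directory "./hists" are illustrative
# assumptions.
def _example_prediction_histogram() -> None:
    save_dir = Path("./hists")
    save_dir.mkdir(exist_ok=True)
    hist = PredictionHistogram(classes=["lesion"], save_dir=save_dir,
                               iou_thresholds=(0.1, 0.5))
    results_list = [{
        0: {
            "dtMatches": np.array([[1, 0], [1, 0]]),
            "gtMatches": np.array([[1], [1]]),
            "dtScores": np.array([0.8, 0.4]),
            "gtIgnore": np.zeros(1),
            "dtIgnore": np.zeros((2, 2)),
        }
    }]
    hist.compute(results_list)  # writes one histogram per IoU and per class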
import numpy as np
from typing import Callable, Sequence, List, Dict
__all__ = ["matching_batch"]
def matching_batch(
iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
iou_thresholds: Sequence[float], pred_boxes: Sequence[np.ndarray],
pred_classes: Sequence[np.ndarray], pred_scores: Sequence[np.ndarray],
gt_boxes: Sequence[np.ndarray], gt_classes: Sequence[np.ndarray],
gt_ignore: Sequence[Sequence[bool]], max_detections: int = 100,
) -> List[Dict[int, Dict[str, np.ndarray]]]:
"""
Match boxes of a batch to corresponding ground truth for each category
independently
Args:
iou_fn: compute overlap for each pair
iou_thresholds: defined which IoU thresholds should be evaluated
pred_boxes: predicted boxes from single batch; List[[D, dim * 2]],
D number of predictions
pred_classes: predicted classes from a single batch; List[[D]],
D number of predictions
pred_scores: predicted score for each bounding box; List[[D]],
D number of predictions
gt_boxes: ground truth boxes; List[[G, dim * 2]], G number of ground
truth
gt_classes: ground truth classes; List[[G]], G number of ground truth
gt_ignore: specifies which ground truth boxes are not counted as
true positives
(detections which match these boxes are not counted as false
positives either); List[[G]], G number of ground truth
max_detections: maximum number of detections which should be evaluated
Returns:
List[Dict[int, Dict[str, np.ndarray]]]
matched detections [dtMatches] and ground truth [gtMatches]
boxes [str, np.ndarray] for each category (stored in dict keys)
for each image (list)
"""
results = []
# iterate over images/batches
for pboxes, pclasses, pscores, gboxes, gclasses, gignore in zip(
pred_boxes, pred_classes, pred_scores, gt_boxes, gt_classes, gt_ignore):
img_classes = np.union1d(pclasses, gclasses)
result = {} # dict contains results for each class in one image
for c in img_classes:
pred_mask = pclasses == c # mask predictions with current class
gt_mask = gclasses == c # mask ground truth with current class
if not np.any(gt_mask): # no ground truth
result[c] = _matching_no_gt(
iou_thresholds=iou_thresholds,
pred_scores=pscores[pred_mask],
max_detections=max_detections)
elif not np.any(pred_mask): # no predictions
result[c] = _matching_no_pred(
iou_thresholds=iou_thresholds,
gt_ignore=gignore[gt_mask],
)
else: # at least one prediction and one ground truth
result[c] = _matching_single_image_single_class(
iou_fn=iou_fn,
pred_boxes=pboxes[pred_mask],
pred_scores=pscores[pred_mask],
gt_boxes=gboxes[gt_mask],
gt_ignore=gignore[gt_mask],
max_detections=max_detections,
iou_thresholds=iou_thresholds,
)
results.append(result)
return results
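# Example (sketch): matching one image with a single class. The 2D IoU helper
# below is an illustrative stand-in for nndet's box_iou_np and expects boxes
# in (x1, y1, x2, y2) format; all values are assumptions for demonstration.
def _example_matching_batch() -> List[Dict[int, Dict[str, np.ndarray]]]:
    def toy_iou(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
        # pairwise IoU for axis aligned 2D boxes
        ious = np.zeros((len(boxes1), len(boxes2)))
        for i, b1 in enumerate(boxes1):
            for j, b2 in enumerate(boxes2):
                x1, y1 = max(b1[0], b2[0]), max(b1[1], b2[1])
                x2, y2 = min(b1[2], b2[2]), min(b1[3], b2[3])
                inter = max(x2 - x1, 0) * max(y2 - y1, 0)
                area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
                area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
                ious[i, j] = inter / (area1 + area2 - inter)
        return ious

    return matching_batch(
        iou_fn=toy_iou,
        iou_thresholds=(0.1, 0.5),
        pred_boxes=[np.array([[10, 10, 20, 20], [40, 40, 50, 50]])],
        pred_classes=[np.array([0, 0])],
        pred_scores=[np.array([0.9, 0.4])],
        gt_boxes=[np.array([[11, 11, 21, 21]])],
        gt_classes=[np.array([0])],
        gt_ignore=[np.zeros(1)],
    )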
def _matching_no_gt(
iou_thresholds: Sequence[float],
pred_scores: np.ndarray,
max_detections: int,
):
"""
Matching result for an image without ground truth
Args:
iou_thresholds: defined which IoU thresholds should be evaluated
pred_scores: predicted scores
max_detections: maximum number of allowed detections per image.
This functions uses this parameter to stay consistent with
the actual matching function which needs this limit.
Returns:
dict: computed matching
`dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number
of thresholds, G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored
[G] indicate whether ground truth should be ignored
`dtIgnore`: detections which should be ignored [T, D],
indicate which detections should be ignored
"""
dt_ind = np.argsort(-pred_scores, kind='mergesort')
dt_ind = dt_ind[:max_detections]
dt_scores = pred_scores[dt_ind]
num_preds = len(dt_scores)
gt_match = np.array([[]] * len(iou_thresholds))
dt_match = np.zeros((len(iou_thresholds), num_preds))
dt_ignore = np.zeros((len(iou_thresholds), num_preds))
return {
'dtMatches': dt_match, # [T, D], where T = number of thresholds, D = number of detections
'gtMatches': gt_match, # [T, G], where T = number of thresholds, G = number of ground truth
'dtScores': dt_scores, # [D] detection scores
'gtIgnore': np.array([]).reshape(-1), # [G] indicate whether ground truth should be ignored
'dtIgnore': dt_ignore, # [T, D], indicate which detections should be ignored
}
def _matching_no_pred(
iou_thresholds: Sequence[float],
gt_ignore: np.ndarray,
):
"""
Matching result with no predictions
Args:
        iou_thresholds: defines which IoU thresholds should be evaluated
        gt_ignore: specifies which ground truth boxes are not counted as
            true positives (detections which match these boxes are not
            counted as false positives either); [G], G number of ground truth
Returns:
dict: computed matching
`dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number
of thresholds, G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored
[G] indicate whether ground truth should be ignored
`dtIgnore`: detections which should be ignored [T, D],
indicate which detections should be ignored
"""
dt_scores = np.array([])
dt_match = np.array([[]] * len(iou_thresholds))
dt_ignore = np.array([[]] * len(iou_thresholds))
gt_match = np.zeros((len(iou_thresholds), len(gt_ignore)))
return {
'dtMatches': dt_match, # [T, D], where T = number of thresholds, D = number of detections
'gtMatches': gt_match, # [T, G], where T = number of thresholds, G = number of ground truth
'dtScores': dt_scores, # [D] detection scores
'gtIgnore': gt_ignore.reshape(-1), # [G] indicate whether ground truth should be ignored
'dtIgnore': dt_ignore, # [T, D], indicate which detections should be ignored
}
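# A minimal sketch of the two edge-case helpers above; the thresholds, scores
# and array sizes are illustrative only and `_demo_matching_edge_cases` is not
# part of the original module. With two thresholds, three predictions and no
# ground truth, all detections stay unmatched; with two ground truth boxes and
# no predictions, all ground truth stays unmatched.
def _demo_matching_edge_cases() -> None:
    no_gt = _matching_no_gt(
        iou_thresholds=(0.1, 0.5),
        pred_scores=np.array([0.9, 0.5, 0.7]),
        max_detections=100,
    )
    assert no_gt["dtMatches"].shape == (2, 3)  # [T, D], all zeros
    no_pred = _matching_no_pred(
        iou_thresholds=(0.1, 0.5),
        gt_ignore=np.zeros(2),
    )
    assert no_pred["gtMatches"].shape == (2, 2)  # [T, G], all zeros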
def _matching_single_image_single_class(
iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
pred_boxes: np.ndarray,
pred_scores: np.ndarray,
gt_boxes: np.ndarray,
gt_ignore: np.ndarray,
max_detections: int,
iou_thresholds: Sequence[float],
) -> Dict[str, np.ndarray]:
"""
Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py
Args:
        iou_fn: computes overlap for each pair of boxes
        iou_thresholds: defines which IoU thresholds should be evaluated
        pred_boxes: predicted boxes from a single batch; [D, dim * 2],
            D number of predictions
        pred_scores: predicted score for each bounding box; [D],
            D number of predictions
        gt_boxes: ground truth boxes; [G, dim * 2], G number of ground truth
        gt_ignore: specifies which ground truth boxes are not counted as
            true positives (detections which match these boxes are not
            counted as false positives either); [G], G number of ground truth
max_detections: maximum number of detections which should be evaluated
Returns:
dict: computed matching
`dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number
of thresholds, G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored
[G] indicate whether ground truth should be ignored
`dtIgnore`: detections which should be ignored [T, D],
indicate which detections should be ignored
"""
# filter for max_detections highest scoring predictions to speed up computation
dt_ind = np.argsort(-pred_scores, kind='mergesort')
dt_ind = dt_ind[:max_detections]
pred_boxes = pred_boxes[dt_ind]
pred_scores = pred_scores[dt_ind]
# sort ignored ground truth to last positions
gt_ind = np.argsort(gt_ignore, kind='mergesort')
gt_boxes = gt_boxes[gt_ind]
gt_ignore = gt_ignore[gt_ind]
# ious between sorted(!) predictions and ground truth
ious = iou_fn(pred_boxes, gt_boxes)
num_preds, num_gts = ious.shape[0], ious.shape[1]
gt_match = np.zeros((len(iou_thresholds), num_gts))
dt_match = np.zeros((len(iou_thresholds), num_preds))
dt_ignore = np.zeros((len(iou_thresholds), num_preds))
for tind, t in enumerate(iou_thresholds):
for dind, _d in enumerate(pred_boxes): # iterate detections starting from highest scoring one
# information about best match so far (m=-1 -> unmatched)
iou = min([t, 1-1e-10])
m = -1
for gind, _g in enumerate(gt_boxes): # iterate ground truth
# if this gt already matched, continue
if gt_match[tind, gind] > 0:
continue
                # if this detection already matched a regular gt and we have
                # reached the ignored gt (sorted to the end), stop searching
if m > -1 and gt_ignore[m] == 0 and gt_ignore[gind] == 1:
break
# continue to next gt unless better match made
if ious[dind, gind] < iou:
continue
# if match successful and best so far, store appropriately
iou = ious[dind, gind]
m = gind
            # if a match was made, mark both detection and ground truth as
            # matched and record whether the matched gt should be ignored
if m == -1:
continue
else:
dt_ignore[tind, dind] = int(gt_ignore[m])
dt_match[tind, dind] = 1
gt_match[tind, m] = 1
# store results for given image and category
return {
'dtMatches': dt_match, # [T, D], where T = number of thresholds, D = number of detections
'gtMatches': gt_match, # [T, G], where T = number of thresholds, G = number of ground truth
'dtScores': pred_scores, # [D] detection scores
'gtIgnore': gt_ignore.reshape(-1), # [G] indicate whether ground truth should be ignored
'dtIgnore': dt_ignore, # [T, D], indicate which detections should be ignored
}
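# A minimal sketch of the greedy matching above for a single image and class.
# The axis-aligned 2D IoU helper and the toy boxes are illustrative
# assumptions and not part of the evaluator API; any `iou_fn` with the same
# signature can be used instead.
def _demo_iou_2d(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """Pairwise IoU for boxes in (x1, y1, x2, y2) format; returns [D, G]."""
    ious = np.zeros((boxes1.shape[0], boxes2.shape[0]))
    for i, (x1, y1, x2, y2) in enumerate(boxes1):
        for j, (a1, b1, a2, b2) in enumerate(boxes2):
            iw = max(0., min(x2, a2) - max(x1, a1))
            ih = max(0., min(y2, b2) - max(y1, b1))
            inter = iw * ih
            union = (x2 - x1) * (y2 - y1) + (a2 - a1) * (b2 - b1) - inter
            ious[i, j] = inter / union if union > 0 else 0.
    return ious
def _demo_matching_single_image_single_class() -> Dict[str, np.ndarray]:
    """Two detections vs. one ground truth box at IoU thresholds 0.1 and 0.5."""
    result = _matching_single_image_single_class(
        iou_fn=_demo_iou_2d,
        pred_boxes=np.array([[0., 0., 10., 10.], [20., 20., 30., 30.]]),
        pred_scores=np.array([0.9, 0.8]),
        gt_boxes=np.array([[1., 1., 11., 11.]]),
        gt_ignore=np.zeros(1),
        max_detections=100,
        iou_thresholds=(0.1, 0.5),
    )
    # the first detection overlaps the ground truth (IoU ~ 0.68) and is
    # matched at both thresholds; the second detection stays unmatched
    assert result["dtMatches"].tolist() == [[1., 0.], [1., 0.]]
    return result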
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from os import PathLike
from pathlib import Path
from typing import Dict, Sequence, Optional, Tuple
import numpy as np
from loguru import logger
from nndet.io.load import load_pickle, save_json, save_pickle
from nndet.evaluator.det import BoxEvaluator
from nndet.evaluator.case import CaseEvaluator
from nndet.evaluator.seg import PerCaseSegmentationEvaluator
def save_metric_output(scores, curves, base_dir, name):
"""
    Save metric scores and curves: scalar scores as a human readable JSON
    file and the full scores and curves as a pickle file
"""
scores_string = {str(key): str(item) for key, item in scores.items()}
save_json(scores_string, base_dir / f"{name}.json")
save_pickle({"scores": scores, "curves": curves}, base_dir / f"{name}.pkl")
def evaluate_box_dir(
pred_dir: PathLike,
gt_dir: PathLike,
classes: Sequence[str],
save_dir: Optional[Path] = None,
) -> Tuple[Dict, Dict]:
"""
Run box evaluation inside a directory
Args:
pred_dir: path to dir with predictions
        gt_dir: path to dir with ground truth data
classes: classes present in dataset
save_dir: optional path to save plots
Returns:
Dict[str, float]: dictionary with scalar values for evaluation
Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
See Also:
:class:`nndet.evaluator.registry.BoxEvaluator`
"""
pred_dir = Path(pred_dir)
gt_dir = Path(gt_dir)
if save_dir is not None:
save_dir.mkdir(parents=True, exist_ok=True)
case_ids = [p.stem.rsplit('_boxes', 1)[0] for p in pred_dir.iterdir()
if p.is_file() and p.stem.endswith("_boxes")]
    logger.info(f"Found {len(case_ids)} cases for box evaluation in {pred_dir}")
evaluator = BoxEvaluator.create(classes=classes,
fast=False,
verbose=False,
save_dir=save_dir,
)
for case_id in case_ids:
gt = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
pred = load_pickle(pred_dir / f"{case_id}_boxes.pkl")
evaluator.run_online_evaluation(
pred_boxes=[pred["pred_boxes"]], pred_classes=[pred["pred_labels"]],
pred_scores=[pred["pred_scores"]], gt_boxes=[gt["boxes"]],
gt_classes=[gt["classes"]], gt_ignore=None,
)
return evaluator.finish_online_evaluation()
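# A minimal sketch of the on-disk layout `evaluate_box_dir` (and likewise
# `evaluate_case_dir`) expects: one `{case_id}_boxes.pkl` per prediction and
# one `{case_id}_boxes_gt.npz` per ground truth case. The case id, boxes and
# class name below are illustrative assumptions, not part of the module.
def _demo_write_box_case(pred_dir: Path, gt_dir: Path, case_id: str = "case_000") -> None:
    """Write a single toy case in the expected prediction/ground truth format."""
    pred = {
        "pred_boxes": np.array([[0., 0., 10., 10.]]),  # [D, dim * 2]
        "pred_labels": np.array([0]),                   # [D]
        "pred_scores": np.array([0.9]),                 # [D]
    }
    save_pickle(pred, pred_dir / f"{case_id}_boxes.pkl")
    np.savez(str(gt_dir / f"{case_id}_boxes_gt.npz"),
             boxes=np.array([[1., 1., 11., 11.]]),      # [G, dim * 2]
             classes=np.array([0]))                      # [G]
# afterwards e.g.: scores, curves = evaluate_box_dir(pred_dir, gt_dir, classes=["lesion"])
# and optionally save_metric_output(scores, curves, save_dir, "results_boxes")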
def evaluate_case_dir(
pred_dir: PathLike,
gt_dir: PathLike,
classes: Sequence[str],
target_class: Optional[int] = None,
) -> Tuple[Dict, Dict]:
"""
Run evaluation of case results inside a directory
Args:
pred_dir: path to dir with predictions
        gt_dir: path to dir with ground truth data
        classes: classes present in dataset
        target_class: in case of multiple classes, specify a target class
            to evaluate in a target class vs rest setting
Returns:
Dict[str, float]: dictionary with scalar values for evaluation
        Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
See Also:
:class:`nndet.evaluator.registry.CaseEvaluator`
"""
pred_dir = Path(pred_dir)
gt_dir = Path(gt_dir)
case_ids = [p.stem.rsplit('_boxes', 1)[0] for p in pred_dir.iterdir()
if p.is_file() and p.stem.endswith("_boxes")]
    logger.info(f"Found {len(case_ids)} cases for case evaluation in {pred_dir}")
evaluator = CaseEvaluator.create(classes=classes,
target_class=target_class,
)
for case_id in case_ids:
gt = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
pred = load_pickle(pred_dir / f"{case_id}_boxes.pkl")
evaluator.run_online_evaluation(
pred_classes=[pred["pred_labels"]],
pred_scores=[pred["pred_scores"]],
gt_classes=[gt["classes"]]
)
return evaluator.finish_online_evaluation()
def evaluate_seg_dir(
pred_dir: PathLike,
gt_dir: PathLike,
classes: Sequence[str],
) -> Tuple[Dict, None]:
"""
Compute dice metric across a directory
Args:
pred_dir: path to dir with predictions
        gt_dir: path to dir with ground truth data
classes: classes present in dataset
Returns:
Dict[str, float]: dictionary with scalar values for evaluation
None
See Also:
:class:`nndet.evaluator.registry.PerCaseSegmentationEvaluator`
"""
pred_dir = Path(pred_dir)
gt_dir = Path(gt_dir)
case_ids = [p.stem.rsplit('_seg', 1)[0] for p in pred_dir.iterdir()
if p.is_file() and p.stem.endswith("_seg")]
    logger.info(f"Found {len(case_ids)} cases for seg evaluation in {pred_dir}")
evaluator = PerCaseSegmentationEvaluator.create(classes=classes)
for case_id in case_ids:
gt = np.load(str(gt_dir / f"{case_id}_seg_gt.npz"), allow_pickle=True)["seg"] # 1, dims
pred = load_pickle(pred_dir / f"{case_id}_seg.pkl")
evaluator.run_online_evaluation(
seg=pred[None],
target=gt,
)
return evaluator.finish_online_evaluation()
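# A minimal sketch of the on-disk layout `evaluate_seg_dir` expects: predicted
# segmentations as `{case_id}_seg.pkl` (a label map of shape [dims]) and
# ground truth as `{case_id}_seg_gt.npz` with key `seg` shaped [1, dims]. The
# toy arrays and case id are illustrative assumptions, not part of the module.
def _demo_write_seg_case(pred_dir: Path, gt_dir: Path, case_id: str = "case_000") -> None:
    """Write a single toy case for segmentation evaluation."""
    pred_seg = np.zeros((8, 8, 8), dtype=np.int32)   # [dims]; expanded to [1, dims] internally
    pred_seg[:4] = 1
    gt_seg = np.zeros((1, 8, 8, 8), dtype=np.int32)  # [1, dims]
    gt_seg[0, 2:6] = 1
    save_pickle(pred_seg, pred_dir / f"{case_id}_seg.pkl")
    np.savez(str(gt_dir / f"{case_id}_seg_gt.npz"), seg=gt_seg)
# afterwards e.g.: scores, _ = evaluate_seg_dir(pred_dir, gt_dir, classes=["background", "lesion"])
# (listing background first is an assumption based on how the evaluator indexes classes)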
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from typing import Dict, Sequence, Tuple
from collections import defaultdict
from nndet.evaluator import AbstractEvaluator
__all__ = ["SegmentationEvaluator"]
class SegmentationEvaluator(AbstractEvaluator):
def __init__(self,
per_class: bool = True,
*args,
**kwargs,
):
"""
Compute dice score during training
"""
self.per_class = per_class
self.results_list = defaultdict(list)
def reset(self):
"""
Reset internal state for new epoch
"""
self.results_list = defaultdict(list)
def run_online_evaluation(self,
seg_probs: np.ndarray,
target: np.ndarray,
) -> Dict:
"""
Run evaluation of one batch and save internal results for later
Args:
seg_probs: output probabilities of network [N, C, dims], where N
is the batch size, C is the number of classes, dims are
spatial dimensions
target: ground truth segmentation [N, dims], where N is the batch
size and dims are spatial dimensions
Returns:
Dict: empty dict
"""
num_classes = seg_probs.shape[1]
output_seg = np.argmax(seg_probs, axis=1).reshape((seg_probs.shape[0], -1))
target = target.reshape((target.shape[0], -1))
tp_hard = np.zeros((target.shape[0], num_classes - 1))
fp_hard = np.zeros((target.shape[0], num_classes - 1))
fn_hard = np.zeros((target.shape[0], num_classes - 1))
for c in range(1, num_classes):
tp_hard[:, c - 1] = ((output_seg == c).astype(np.float32) * (target == c).astype(np.float32)).sum(axis=1)
fp_hard[:, c - 1] = ((output_seg == c).astype(np.float32) * (target != c).astype(np.float32)).sum(axis=1)
fn_hard[:, c - 1] = ((output_seg != c).astype(np.float32) * (target == c).astype(np.float32)).sum(axis=1)
tp_hard = tp_hard.sum(axis=0)
fp_hard = fp_hard.sum(axis=0)
fn_hard = fn_hard.sum(axis=0)
        # collect per-batch foreground dice for later inspection/logging
        self.results_list["fg_dice"].append(list(
            (2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
self.results_list["tp"].append(tp_hard)
self.results_list["fp"].append(fp_hard)
self.results_list["fn"].append(fn_hard)
return {}
    def finish_online_evaluation(self) -> Tuple[Dict[str, float], None]:
"""
Summarize results from batches and compute global dice and global
dice per class
Returns:
Dict: results
`{cls_idx}_seg_dice`: global dice per class
`seg_dice`: global dice over all classes
"""
results = {}
if self.results_list:
tp = np.sum(self.results_list["tp"], 0)
fp = np.sum(self.results_list["fp"], 0)
fn = np.sum(self.results_list["fn"], 0)
global_dc_per_class = [
i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(tp, fp, fn)] if not np.isnan(i)]
if self.per_class:
for cls_idx, dc in enumerate(global_dc_per_class):
results[f"{cls_idx}_seg_dice"] = dc
results["seg_dice"] = np.mean(global_dc_per_class)
return results, None
@classmethod
def create(cls,
per_class: bool = False,
):
return cls(per_class=per_class)
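# A minimal usage sketch for `SegmentationEvaluator`, assuming a batch of
# softmax probabilities with one background and one foreground class; the
# random arrays and shapes are illustrative only and the helper is not part
# of the original module.
def _demo_segmentation_evaluator() -> Dict[str, float]:
    rng = np.random.default_rng(0)
    seg_probs = rng.random((2, 2, 8, 8, 8))             # [N, C, dims]
    seg_probs /= seg_probs.sum(axis=1, keepdims=True)   # normalize to probabilities
    target = rng.integers(0, 2, size=(2, 8, 8, 8))      # [N, dims]
    evaluator = SegmentationEvaluator.create(per_class=True)
    evaluator.run_online_evaluation(seg_probs=seg_probs, target=target)
    scores, _ = evaluator.finish_online_evaluation()
    return scores  # e.g. {"0_seg_dice": ..., "seg_dice": ...}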
class PerCaseSegmentationEvaluator(AbstractEvaluator):
def __init__(self,
classes: Sequence[str],
*args,
**kwargs,
):
"""
Compute dice score per case and average results over dataset
"""
self.classes = classes
self.results = []
def reset(self):
"""
Reset internal state for new epoch
"""
self.results = []
def run_online_evaluation(self,
seg: np.ndarray,
target: np.ndarray,
) -> Dict:
"""
Run evaluation of one batch and save internal results for later
Args:
            seg: predicted segmentation as integer labels [N, dims], where N
                is the batch size and dims are spatial dimensions
target: ground truth segmentation [N, dims], where N is the batch
size and dims are spatial dimensions
Returns:
Dict: empty dict
"""
assert len(seg) == len(target)
num_classes = len(self.classes)
output_seg = seg.reshape((seg.shape[0], -1)) # N, X
target = target.reshape((target.shape[0], -1)) # N, X
tp_hard = np.zeros((target.shape[0], num_classes - 1)) # N, FG
fp_hard = np.zeros((target.shape[0], num_classes - 1)) # N ,FG
fn_hard = np.zeros((target.shape[0], num_classes - 1)) # N, FG
fg_present = np.zeros((target.shape[0], num_classes - 1)) # N, FG
for c in range(1, num_classes):
tp_hard[:, c - 1] = ((output_seg == c).astype(np.float32) * (target == c).astype(np.float32)).sum(axis=1)
fp_hard[:, c - 1] = ((output_seg == c).astype(np.float32) * (target != c).astype(np.float32)).sum(axis=1)
fn_hard[:, c - 1] = ((output_seg != c).astype(np.float32) * (target == c).astype(np.float32)).sum(axis=1)
fg_present[:, c - 1] = (target == c).any(axis=1).astype(np.int32)
dice = np.where(fg_present, 2. * tp_hard / (2 * tp_hard + fp_hard + fn_hard), np.nan) # N, FG
self.results.append(dice)
return {}
    def finish_online_evaluation(self) -> Tuple[Dict[str, float], None]:
        """
        Summarize results from batches and compute the mean per-case dice
        per class and averaged over all classes
        Returns:
            Dict: results
                `dice_cls_{cls_idx}`: mean per-case dice per class
                `dice`: mean per-case dice over all classes
        """
dice_full = np.concatenate(self.results, axis=0)
        # ignore NaN entries from cases where a class is absent in the ground truth
        dice_per_class = np.nanmean(dice_full, axis=0)  # C
        dice = np.nanmean(dice_full)  # 1
results = {}
for cls_idx, value in enumerate(dice_per_class):
results[f"dice_cls_{cls_idx}"] = float(value)
results["dice"] = float(dice)
return results, None
@classmethod
def create(cls,
classes: Sequence[str],
):
return cls(classes=classes)
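# A minimal usage sketch for `PerCaseSegmentationEvaluator` with one toy case
# of already argmax-ed predictions; the class names and arrays are
# illustrative assumptions (background is listed first so that foreground
# classes start at label 1) and the helper is not part of the original module.
def _demo_per_case_segmentation_evaluator() -> Dict[str, float]:
    seg = np.zeros((1, 4, 4, 4), dtype=np.int32)     # [N, dims], predicted label map
    target = np.zeros((1, 4, 4, 4), dtype=np.int32)  # [N, dims], ground truth label map
    seg[0, :2] = 1                                    # predicted foreground
    target[0, 1:3] = 1                                # true foreground, partial overlap
    evaluator = PerCaseSegmentationEvaluator.create(classes=["background", "lesion"])
    evaluator.run_online_evaluation(seg=seg, target=target)
    scores, _ = evaluator.finish_online_evaluation()
    return scores  # {"dice_cls_0": 0.5, "dice": 0.5} for this toy overlap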