Commit 7246044d authored by mibaumgartner's avatar mibaumgartner
Browse files

Merge remote-tracking branch 'origin/master' into main

parents fcec502f 6f4c3333
from nndet.evaluator.abstract import AbstractMetric, AbstractEvaluator, DetectionMetric
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import abstractmethod, ABC
from typing import Dict, List, Sequence, Tuple

import numpy as np
__all__ = ["AbstractEvaluator", "AbstractMetric", "DetectionMetric"]
class AbstractEvaluator(ABC):
    """Interface for evaluators that cache per-batch results and compute final metrics."""

    @abstractmethod
    def run_online_evaluation(self, *args, **kwargs):
        """Cache whatever values from the current batch are needed for the final evaluation."""
        raise NotImplementedError

    @abstractmethod
    def finish_online_evaluation(self, *args, **kwargs):
        """Aggregate the cached per-batch values and compute the final metrics."""
        raise NotImplementedError

    @abstractmethod
    def reset(self):
        """Clear all cached state so the evaluator can be reused (e.g. for a new epoch)."""
        raise NotImplementedError
class AbstractMetric(ABC):
    """Base class for metrics computed from matched detection results."""

    def __call__(self, *args, **kwargs) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute metric. See :func:`compute` for more information.

        Args:
            *args: positional arguments passed to :func:`compute`
            **kwargs: keyword arguments passed to :func:`compute`

        Returns:
            Dict[str, float]: dictionary with scalar values for evaluation
            Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
        """
        return self.compute(*args, **kwargs)

    @abstractmethod
    def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]],
                ) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute metric

        Args:
            results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
                per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
                `dtMatches`: matched detections [T, G], where T = number of thresholds, G = number of ground truth
                `gtMatches`: matched ground truth boxes [T, D], where T = number of thresholds,
                    D = number of detections
                `dtScores`: prediction scores [D] detection scores
                `gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth
                    should be ignored
                `dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored

        Returns:
            Dict[str, float]: dictionary with scalar values for evaluation
            Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
        """
        raise NotImplementedError
class DetectionMetric(AbstractMetric):
    """Metric that operates on matched detections at a fixed set of IoU thresholds."""

    @abstractmethod
    def get_iou_thresholds(self) -> Sequence[float]:
        """
        IoU thresholds required by this metric.

        Returns:
            Sequence[float]: IoU thresholds; [M], M is the number of thresholds
        """
        raise NotImplementedError

    def check_number_of_iou(self, *args) -> None:
        """
        Verify that the leading (IoU) dimension of every argument matches
        the number of IoU thresholds of this metric.

        Args:
            args: array like inputs which provide a shape attribute
        """
        expected = len(self.get_iou_thresholds())
        for array in args:
            assert array.shape[0] == expected
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
from typing import Dict, Sequence, Callable, Tuple, Union, Mapping, Optional
import numpy as np
from loguru import logger
from sklearn.metrics import accuracy_score, average_precision_score, confusion_matrix, \
f1_score, precision_score, recall_score, roc_auc_score
from nndet.evaluator import AbstractEvaluator
__all__ = ["CaseEvaluator"]
class _CaseEvaluator(AbstractEvaluator):
    def __init__(self,
                 classes: Sequence[Union[str, int]],
                 score_metrics_scalar: Mapping[str, Callable] = None,
                 class_metrics_scalar: Mapping[str, Callable] = None,
                 score_metrics_curve: Mapping[str, Callable] = None,
                 class_metrics_curve: Mapping[str, Callable] = None,
                 target_class: Optional[int] = None,
                 ):
        """
        Compute case level evaluation metrics

        Predictions for individual instances are aggregated by using the
        max of the predicted score for each class. Final class prediction
        is computed by an argmax over that scores. The mappings of the
        metrics are later used as the keys of the result dict.

        Args:
            classes: classes present in whole dataset
            score_metrics_scalar: metrics which accept ground truth classes [N]
                and prediction scores [N, C] for evaluation; N is the number
                of cases and C is the number of classes. The output should
                be a scalar.
            class_metrics_scalar: metrics which accept ground truth classes [N]
                and prediction classes [N] for evaluation; N is the number
                of cases. The output should be a scalar.
            score_metrics_curve: metrics which accept ground truth classes [N]
                and prediction scores [N, C] for evaluation; N is the number
                of cases and C is the number of classes. The output should be
                an array like object.
            class_metrics_curve: metrics which accept ground truth classes [N]
                and prediction classes [N] for evaluation; N is the number
                of cases. The output should be an array like object.
            target_class: target class for case evaluation (internally
                results are evaluated in a binary case target class vs rest).
                If None, fall back to fg vs bg

        Raises:
            ValueError: if the target class is passed as a string name
                instead of an integer index
        """
        self.results_list = defaultdict(list)
        self.score_metrics_scalar = score_metrics_scalar if score_metrics_scalar is not None else {}
        self.class_metrics_scalar = class_metrics_scalar if class_metrics_scalar is not None else {}
        self.score_metrics_curve = score_metrics_curve if score_metrics_curve is not None else {}
        self.class_metrics_curve = class_metrics_curve if class_metrics_curve is not None else {}

        if isinstance(target_class, str):
            raise ValueError("Need integer value of target class not the name!")
        # bug fix: `int(target_class)` raised a TypeError when target_class
        # was left at its default of None (None is valid: fg-vs-bg mode)
        self.target_class = int(target_class) if target_class is not None else None
        self.classes = classes
        self.num_classes = len(classes)

    def reset(self):
        """
        Reset internal state for new epoch
        """
        self.results_list = defaultdict(list)

    def run_online_evaluation(self,
                              pred_classes: Sequence[np.ndarray],
                              pred_scores: Sequence[np.ndarray],
                              gt_classes: Sequence[np.ndarray],
                              ) -> Dict:
        """
        Run evaluation on each case (accepts a batch of case results
        at once).

        Args:
            pred_classes (Sequence[np.ndarray]): predicted classes from a batch
                of cases; List[[D]], D number of predictions
            pred_scores (Sequence[np.ndarray]): predicted score for each
                bounding box; List[[D]], D number of predictions
            gt_classes (Sequence[np.ndarray]): ground truth classes for each
                instance in a case; List[[G]], G number of ground truth

        Returns:
            Dict: empty dict

        Notes:
            This caches the max predicted probability per class per element
            and the unique classes present per element.
        """
        case_classes = [np.unique(gtc) for gtc in gt_classes]

        case_scores = []
        for case_instance_scores, case_instance_classes in zip(pred_scores, pred_classes):
            # keep the maximum instance score per class as the case score
            _scores = np.zeros(self.num_classes)
            for instance_score, instance_class in zip(case_instance_scores, case_instance_classes):
                if _scores[int(instance_class)] < instance_score:
                    _scores[int(instance_class)] = instance_score
            case_scores.append(_scores)

        self.results_list["case_classes"].extend(case_classes)
        self.results_list["case_scores"].extend(case_scores)
        return {}

    def finish_online_evaluation(self) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute final scores and curves of metrics

        Returns:
            Dict: results of scalar metrics
            Dict: results of curve metrics
        """
        # aggregate cases
        gt_classes = self.aggregate_classes()
        pred_scores, pred_classes = self.aggregate_predictions()

        # compute metrics
        curve_results = {}
        for key, metric in self.score_metrics_curve.items():
            curve_results[key] = metric(gt_classes, pred_scores)
        for key, metric in self.class_metrics_curve.items():
            curve_results[key] = metric(gt_classes, pred_classes)

        # scalar metrics may legitimately fail (e.g. only one class present);
        # log and record NaN instead of aborting the whole evaluation
        scalar_results = {}
        for key, metric in self.score_metrics_scalar.items():
            try:
                scalar_results[key] = metric(gt_classes, pred_scores)
            except (ValueError, RuntimeError) as e:
                logger.warning(f"Metric {key} exited with error {e}; writing nan to result")
                scalar_results[key] = np.nan
        for key, metric in self.class_metrics_scalar.items():
            try:
                scalar_results[key] = metric(gt_classes, pred_classes)
            except (ValueError, RuntimeError) as e:
                logger.warning(f"Metric {key} exited with error {e}; writing nan to result")
                scalar_results[key] = np.nan
        return scalar_results, curve_results

    def aggregate_classes(self) -> np.ndarray:
        """
        Aggregate classes of each instance in a case to one case class

        Returns:
            np.ndarray: class per case [N], where N is the number of cases
        """
        if self.target_class is not None:
            # binary target-class-vs-rest labels
            gt_classes = np.asarray(
                [int(self.target_class in cc) for cc in self.results_list["case_classes"]])
        else:
            # binary fg-vs-bg labels: positive if any instance is present
            gt_classes = np.asarray(
                [1 if len(cc) > 0 else 0 for cc in self.results_list["case_classes"]])
        return gt_classes

    def aggregate_predictions(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Aggregate prediction scores per class to case scores with target class

        Returns:
            np.ndarray: predicted scores
            np.ndarray: predicted classes
        """
        _pred_scores = np.stack(self.results_list["case_scores"], axis=0)  # N, num_classes
        if self.target_class is not None:
            pred_scores = _pred_scores[:, self.target_class]  # N
            # pred_classes = (np.argmax(_pred_scores, axis=1) == self.target_class).astype(np.int32) # N
            # This is not always the correct choice, depending on the final
            # nonlinearity of the network (sigmoid vs. softmax)
            pred_classes = (pred_scores > 0.5).astype(np.int32)  # N
        else:
            pred_scores = _pred_scores.max(axis=1)  # N
            pred_classes = (pred_scores > 0.5).astype(np.int32)  # N
        return pred_scores, pred_classes

    # backward-compatible alias for the historical (misspelled) method name
    aggregate_prdictions = aggregate_predictions
class CaseEvaluator(_CaseEvaluator):
    @classmethod
    def create(cls,
               classes: Sequence[str],
               target_class: int = None
               ):
        """
        Build a patient/case level evaluator with a default metric suite.

        Args:
            classes: classes present in dataset
            target_class: if multiple classes are given, define
                a target class to evaluate in an target_class vs rest setting.
                Defaults to None.

        Returns:
            CaseEvaluator: evaluator
        """
        scalar_from_scores = {
            "auc_case": roc_auc_score,
            "ap_case": average_precision_score,
        }
        scalar_from_classes = {
            "f1_case": f1_score,
            "prec_case": precision_score,
            "rec_case": recall_score,
            "acc_case": accuracy_score,
        }
        return cls(
            classes=classes,
            score_metrics_scalar=scalar_from_scores,
            class_metrics_scalar=scalar_from_classes,
            score_metrics_curve={},
            class_metrics_curve={"cfm_case": confusion_matrix},
            target_class=target_class,
        )
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pathlib import Path
from functools import partial
from typing import Optional, Sequence, Callable, Dict, List, Tuple
import numpy as np
from nndet.evaluator.abstract import AbstractEvaluator, DetectionMetric
from nndet.evaluator.detection.matching import matching_batch
from nndet.core.boxes import box_iou_np
from nndet.evaluator.detection.coco import COCOMetric
from nndet.evaluator.detection.froc import FROCMetric
from nndet.evaluator.detection.hist import PredictionHistogram
__all__ = ["DetectionEvaluator"]
class DetectionEvaluator(AbstractEvaluator):
    """Evaluator which matches predictions to ground truth and feeds detection metrics."""

    def __init__(self,
                 metrics: Sequence[DetectionMetric],
                 iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] = box_iou_np,
                 max_detections: int = 100,
                 ):
        """
        Evaluate detection metrics over a dataset.

        Args:
            metrics (Sequence[DetectionMetric]): detection metrics to evaluate
            iou_fn (Callable[[np.ndarray, np.ndarray], np.ndarray]): computes the
                pairwise overlap of two sets of boxes
            max_detections (int): number of maximum detections per image
                (reduces computation)
        """
        self.iou_fn = iou_fn
        self.max_detections = max_detections
        self.metrics = metrics
        self.results_list = []  # matched results of each image

        self.iou_thresholds = self.get_unique_iou_thresholds()
        self.iou_mapping = self.get_indices_of_iou_for_each_metric()

    def get_unique_iou_thresholds(self):
        """
        Sorted union of the IoU thresholds required by all metrics.
        """
        return sorted({th for metric in self.metrics for th in metric.get_iou_thresholds()})

    def get_indices_of_iou_for_each_metric(self):
        """
        For every metric, the indices of its IoU thresholds inside the
        unique threshold list.
        """
        return [[self.iou_thresholds.index(th) for th in metric.get_iou_thresholds()]
                for metric in self.metrics]

    def run_online_evaluation(self,
                              pred_boxes: Sequence[np.ndarray],
                              pred_classes: Sequence[np.ndarray],
                              pred_scores: Sequence[np.ndarray],
                              gt_boxes: Sequence[np.ndarray],
                              gt_classes: Sequence[np.ndarray],
                              gt_ignore: Sequence[Sequence[bool]] = None) -> Dict:
        """
        Match a batch of predictions to ground truth and cache the result
        for the final evaluation.

        Args:
            pred_boxes (Sequence[np.ndarray]): predicted boxes from single batch;
                List[[D, dim * 2]], D number of predictions
            pred_classes (Sequence[np.ndarray]): predicted classes from a single batch;
                List[[D]], D number of predictions
            pred_scores (Sequence[np.ndarray]): predicted score for each bounding box;
                List[[D]], D number of predictions
            gt_boxes (Sequence[np.ndarray]): ground truth boxes;
                List[[G, dim * 2]], G number of ground truth
            gt_classes (Sequence[np.ndarray]): ground truth classes;
                List[[G]], G number of ground truth
            gt_ignore (Sequence[Sequence[bool]]): specifies which ground truth boxes
                are not counted as true positives (detections which match these boxes
                are not counted as false positives either); List[[G]], G number of
                ground truth

        Returns:
            dict: empty dict... detection metrics can only be evaluated at the end
        """
        if gt_ignore is None:
            # by default no ground truth box is ignored
            gt_ignore = [np.zeros(boxes.shape[0]).reshape(-1) for boxes in gt_boxes]

        matched = matching_batch(
            self.iou_fn, self.iou_thresholds, pred_boxes=pred_boxes, pred_classes=pred_classes,
            pred_scores=pred_scores, gt_boxes=gt_boxes, gt_classes=gt_classes, gt_ignore=gt_ignore,
            max_detections=self.max_detections)
        self.results_list.extend(matched)
        return {}

    def finish_online_evaluation(self) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Accumulate results of individual batches and compute final metrics

        Returns:
            Dict[str, float]: dictionary with scalar values for evaluation
            Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
        """
        metric_scores = {}
        metric_curves = {}
        for metric, iou_indices in zip(self.metrics, self.iou_mapping):
            # restrict cached results to the IoU thresholds this metric expects
            filtered_results = [self.iou_filter(res, iou_idx=iou_indices)
                                for res in self.results_list]
            score, curve = metric(filtered_results)
            if score is not None:
                metric_scores.update(score)
            if curve is not None:
                metric_curves.update(curve)
        return metric_scores, metric_curves

    @staticmethod
    def iou_filter(image_dict: Dict[int, Dict[str, np.ndarray]], iou_idx: List[int],
                   filter_keys: Sequence[str] = ('dtMatches', 'gtMatches', 'dtIgnore')):
        """
        Select specific IoU rows from the matching results so each metric
        only sees the IoU thresholds it asked for.

        Args:
            image_dict: dictionary containing :param:`filter_keys` which hold
                IoUs in the first dimension
            iou_idx: indices of IoU values to keep
            filter_keys: keys to filter, by default ('dtMatches', 'gtMatches', 'dtIgnore')

        Returns:
            dict: filtered dictionary
        """
        indices = list(iou_idx)
        filtered = {}
        for cls_key, cls_results in image_dict.items():
            filtered[cls_key] = {
                key: value[indices] if key in filter_keys else value
                for key, value in cls_results.items()
            }
        return filtered

    def reset(self):
        """
        Reset internal state of evaluator
        """
        self.results_list = []
class BoxEvaluator(DetectionEvaluator):
    @classmethod
    def create(cls,
               classes: Sequence[str],
               fast: bool = True,
               verbose: bool = False,
               save_dir: Optional[Path] = None,
               ):
        """
        Create a box evaluator object

        Args:
            classes: classes present in the dataset
            fast: reduces the evaluation suite to save time; only evaluates
                IoUs in the range of 0.1-0.5 and does not calculate per class
                metrics
            verbose: additional logging output
            save_dir: path to save information

        Returns:
            BoxEvaluator: evaluator to efficiently compute metrics
        """
        if fast:
            iou_thresholds = (0.1, 0.5)
            per_class = False
        else:
            iou_thresholds = np.arange(0.1, 1.0, 0.1)
            per_class = True

        froc = FROCMetric(classes,
                          iou_thresholds=iou_thresholds,
                          fpi_thresholds=(1 / 8, 1 / 4, 1 / 2, 1, 2, 4, 8),
                          per_class=per_class,
                          verbose=verbose,
                          save_dir=None if fast else save_dir,
                          )
        coco = COCOMetric(classes,
                          iou_list=iou_thresholds,
                          iou_range=(0.1, 0.5, 0.05),
                          max_detection=(100,),
                          per_class=per_class,
                          verbose=verbose,
                          )
        metrics = [froc, coco]
        if not fast:
            metrics.append(
                PredictionHistogram(classes=classes,
                                    save_dir=save_dir,
                                    iou_thresholds=(0.1, 0.5),
                                    )
            )
        return cls(metrics=tuple(metrics), iou_fn=box_iou_np)
from nndet.evaluator.detection.froc import FROCMetric
from nndet.evaluator.detection.coco import COCOMetric
from nndet.evaluator.detection.hist import PredictionHistogram
import time
import numpy as np
from loguru import logger
from typing import Sequence, List, Dict, Union, Tuple
from nndet.evaluator import DetectionMetric
class COCOMetric(DetectionMetric):
    def __init__(self,
                 classes: Sequence[str],
                 iou_list: Sequence[float] = (0.1, 0.5, 0.75),
                 iou_range: Sequence[float] = (0.1, 0.5, 0.05),
                 max_detection: Sequence[int] = (1, 5, 100),
                 per_class: bool = True,
                 verbose: bool = True):
        """
        Class to compute COCO metrics

        Metrics computed:
            mAP over the IoU range specified by :param:`iou_range` at last value of :param:`max_detection`
            AP values at IoU thresholds specified by :param:`iou_list` at last value of :param:`max_detection`
            AR over max detections thresholds defined by :param:`max_detection` (over iou range)

        Args:
            classes (Sequence[str]): name of each class (index needs to correspond to predicted class indices!)
            iou_list (Sequence[float]): specific thresholds where ap is evaluated and saved
            iou_range (Sequence[float]): (start, stop, step) for mAP iou thresholds
            max_detection (Sequence[int]): maximum number of detections per image
            per_class (bool): additionally report every metric per class
            verbose (bool): log time needed for evaluation
        """
        self.verbose = verbose
        self.classes = classes
        self.per_class = per_class

        iou_list = np.array(iou_list)
        # expand (start, stop, step) into the explicit threshold grid
        _iou_range = np.linspace(iou_range[0], iou_range[1],
                                 int(np.round((iou_range[1] - iou_range[0]) / iou_range[2])) + 1, endpoint=True)
        # evaluate each threshold only once even if it appears in both inputs
        self.iou_thresholds = np.union1d(iou_list, _iou_range)
        self.iou_range = iou_range

        # get indices of iou values of ious range and ious list for later evaluation
        self.iou_list_idx = np.nonzero(iou_list[:, np.newaxis] == self.iou_thresholds[np.newaxis])[1]
        self.iou_range_idx = np.nonzero(_iou_range[:, np.newaxis] == self.iou_thresholds[np.newaxis])[1]

        # sanity check: the indices must recover the original threshold values
        assert (self.iou_thresholds[self.iou_list_idx] == iou_list).all()
        assert (self.iou_thresholds[self.iou_range_idx] == _iou_range).all()

        # 101-point recall grid (COCO convention: 0.00, 0.01, ..., 1.00)
        self.recall_thresholds = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
        self.max_detections = max_detection

    def get_iou_thresholds(self) -> Sequence[float]:
        """
        Return IoU thresholds needed for this metric in an numpy array

        Returns:
            Sequence[float]: IoU thresholds [M], M is the number of thresholds
        """
        return self.iou_thresholds

    def compute(self,
                results_list: List[Dict[int, Dict[str, np.ndarray]]],
                ) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute COCO metrics

        Args:
            results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
                per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
                `dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections
                `gtMatches`: matched ground truth boxes [T, G], where T = number of thresholds, G = number of
                    ground truth
                `dtScores`: prediction scores [D] detection scores
                `gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth
                    should be ignored
                `dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored

        Returns:
            Dict[str, float]: dictionary with coco metrics
            Dict[str, np.ndarray]: None (this metric produces no curves)
        """
        if self.verbose:
            logger.info('Start COCO metric computation...')
            tic = time.time()

        dataset_statistics = self.compute_statistics(results_list=results_list)

        if self.verbose:
            toc = time.time()
            logger.info(f'Statistics for COCO metrics finished (t={(toc - tic):0.2f}s).')

        results = {}
        results.update(self.compute_ap(dataset_statistics))
        results.update(self.compute_ar(dataset_statistics))

        if self.verbose:
            toc = time.time()
            logger.info(f'COCO metrics computed in t={(toc - tic):0.2f}s.')
        return results, None

    def compute_ap(self, dataset_statistics: dict) -> dict:
        """
        Compute AP metrics (mAP over the IoU range and AP at single IoU
        thresholds, optionally per class)

        Args:
            dataset_statistics (dict): computed statistics over dataset,
                obtained from :func:`compute_statistics`
                `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                    detection thresholds
                `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                `precision`: Precision values at specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]
                `scores`: Scores corresponding to specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]

        Returns:
            dict: mapping of metric key to AP value
        """
        results = {}
        if self.iou_range:  # mAP
            key = (f"mAP_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_"
                   f"MaxDet_{self.max_detections[-1]}")
            results[key] = self.select_ap(dataset_statistics, iou_idx=self.iou_range_idx, max_det_idx=-1)

            if self.per_class:
                for cls_idx, cls_str in enumerate(self.classes):  # per class results
                    key = (f"{cls_str}_"
                           f"mAP_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_"
                           f"MaxDet_{self.max_detections[-1]}")
                    results[key] = self.select_ap(dataset_statistics, iou_idx=self.iou_range_idx,
                                                  cls_idx=cls_idx, max_det_idx=-1)

        for idx in self.iou_list_idx:  # AP@IoU
            key = f"AP_IoU_{self.iou_thresholds[idx]:.2f}_MaxDet_{self.max_detections[-1]}"
            results[key] = self.select_ap(dataset_statistics, iou_idx=[idx], max_det_idx=-1)

            if self.per_class:
                for cls_idx, cls_str in enumerate(self.classes):  # per class results
                    key = (f"{cls_str}_"
                           f"AP_IoU_{self.iou_thresholds[idx]:.2f}_"
                           f"MaxDet_{self.max_detections[-1]}")
                    results[key] = self.select_ap(dataset_statistics,
                                                  iou_idx=[idx], cls_idx=cls_idx, max_det_idx=-1)
        return results

    def compute_ar(self, dataset_statistics: dict) -> dict:
        """
        Compute AR metrics (mAR per max-detection threshold and AR at single
        IoU thresholds, optionally per class)

        Args:
            dataset_statistics (dict): computed statistics over dataset,
                obtained from :func:`compute_statistics`
                `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                    detection thresholds
                `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                `precision`: Precision values at specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]
                `scores`: Scores corresponding to specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]

        Returns:
            dict: mapping of metric key to AR value
        """
        results = {}
        for max_det_idx, max_det in enumerate(self.max_detections):  # mAR
            key = f"mAR_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_MaxDet_{max_det}"
            results[key] = self.select_ar(dataset_statistics, max_det_idx=max_det_idx)

            if self.per_class:
                for cls_idx, cls_str in enumerate(self.classes):  # per class results
                    key = (f"{cls_str}_"
                           f"mAR_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_"
                           f"MaxDet_{max_det}")
                    results[key] = self.select_ar(dataset_statistics,
                                                  cls_idx=cls_idx, max_det_idx=max_det_idx)

        for idx in self.iou_list_idx:  # AR@IoU
            key = f"AR_IoU_{self.iou_thresholds[idx]:.2f}_MaxDet_{self.max_detections[-1]}"
            results[key] = self.select_ar(dataset_statistics, iou_idx=idx, max_det_idx=-1)

            if self.per_class:
                for cls_idx, cls_str in enumerate(self.classes):  # per class results
                    key = (f"{cls_str}_"
                           f"AR_IoU_{self.iou_thresholds[idx]:.2f}_"
                           f"MaxDet_{self.max_detections[-1]}")
                    results[key] = self.select_ar(dataset_statistics, iou_idx=idx,
                                                  cls_idx=cls_idx, max_det_idx=-1)
        return results

    @staticmethod
    def select_ap(dataset_statistics: dict, iou_idx: Union[int, List[int]] = None,
                  cls_idx: Union[int, Sequence[int]] = None, max_det_idx: int = -1) -> np.ndarray:
        """
        Compute average precision

        Args:
            dataset_statistics (dict): computed statistics over dataset
                `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                    detection thresholds
                `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                `precision`: Precision values at specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]
                `scores`: Scores corresponding to specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]
            iou_idx: index of IoU values to select for evaluation (if None, all values are used)
            cls_idx: class indices to select, if None all classes will be selected
            max_det_idx (int): index to select max detection threshold from data

        Returns:
            np.ndarray: AP value
        """
        prec = dataset_statistics["precision"]
        if iou_idx is not None:
            prec = prec[iou_idx]
        if cls_idx is not None:
            prec = prec[..., cls_idx, :]
        prec = prec[..., max_det_idx]
        # NOTE(review): absent categories are stored as -1 and are averaged in
        # here (unlike select_ar which filters them) — confirm this is intended
        return np.mean(prec)

    @staticmethod
    def select_ar(dataset_statistics: dict, iou_idx: Union[int, Sequence[int]] = None,
                  cls_idx: Union[int, Sequence[int]] = None,
                  max_det_idx: int = -1) -> np.ndarray:
        """
        Compute average recall

        Args:
            dataset_statistics (dict): computed statistics over dataset
                `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                    detection thresholds
                `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                `precision`: Precision values at specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]
                `scores`: Scores corresponding to specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]
            iou_idx: index of IoU values to select for evaluation (if None, all values are used)
            cls_idx: class indices to select, if None all classes will be selected
            max_det_idx (int): index to select max detection threshold from data

        Returns:
            np.ndarray: recall value (-1 if no valid entries exist)
        """
        rec = dataset_statistics["recall"]
        if iou_idx is not None:
            rec = rec[iou_idx]
        if cls_idx is not None:
            rec = rec[..., cls_idx, :]
        rec = rec[..., max_det_idx]

        # entries of -1 mark absent categories and are excluded from the mean
        if len(rec[rec > -1]) == 0:
            rec = -1
        else:
            rec = np.mean(rec[rec > -1])
        return rec

    def compute_statistics(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]
                           ) -> Dict[str, Union[np.ndarray, List]]:
        """
        Compute statistics needed for COCO metrics (mAP, AP of individual classes, mAP@IoU_Thresholds, AR)
        Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py

        Args:
            results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
                per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
                `dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections
                `gtMatches`: matched ground truth boxes [T, G], where T = number of thresholds, G = number of
                    ground truth
                `dtScores`: prediction scores [D] detection scores
                `gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth should be
                    ignored
                `dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored

        Returns:
            dict: computed statistics over dataset
                `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                    detection thresholds
                `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                `precision`: Precision values at specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]
                `scores`: Scores corresponding to specified recall thresholds
                    [num_iou_th, num_recall_th, num_classes, num_max_detections]
        """
        num_iou_th = len(self.iou_thresholds)
        num_recall_th = len(self.recall_thresholds)
        num_classes = len(self.classes)
        num_max_detections = len(self.max_detections)

        # -1 for the precision of absent categories
        precision = -np.ones((num_iou_th, num_recall_th, num_classes, num_max_detections))
        recall = -np.ones((num_iou_th, num_classes, num_max_detections))
        scores = -np.ones((num_iou_th, num_recall_th, num_classes, num_max_detections))

        for cls_idx, cls_i in enumerate(self.classes):  # for each class
            for maxDet_idx, maxDet in enumerate(self.max_detections):  # for each maximum number of detections
                # keep only images which contain matching results for this class
                results = [r[cls_idx] for r in results_list if cls_idx in r]
                if len(results) == 0:
                    logger.warning(f"WARNING, no results found for coco metric for class {cls_i}")
                    continue

                dt_scores = np.concatenate([r['dtScores'][0:maxDet] for r in results])
                # different sorting method generates slightly different results.
                # mergesort is used to be consistent as Matlab implementation.
                inds = np.argsort(-dt_scores, kind='mergesort')
                dt_scores_sorted = dt_scores[inds]

                # r['dtMatches'] [T, R], where R = sum(all detections)
                dt_matches = np.concatenate([r['dtMatches'][:, 0:maxDet] for r in results], axis=1)[:, inds]
                dt_ignores = np.concatenate([r['dtIgnore'][:, 0:maxDet] for r in results], axis=1)[:, inds]
                self.check_number_of_iou(dt_matches, dt_ignores)

                gt_ignore = np.concatenate([r['gtIgnore'] for r in results])
                num_gt = np.count_nonzero(gt_ignore == 0)  # number of ground truth boxes (non ignored)
                if num_gt == 0:
                    logger.warning(f"WARNING, no gt found for coco metric for class {cls_i}")
                    continue

                # ignore cases need to be handled differently for tp and fp
                tps = np.logical_and(dt_matches, np.logical_not(dt_ignores))
                fps = np.logical_and(np.logical_not(dt_matches), np.logical_not(dt_ignores))

                # cumulative tp/fp counts over descending score order
                tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float32)
                fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float32)
                for th_ind, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):  # for each threshold th_ind
                    tp, fp = np.array(tp), np.array(fp)
                    r, p, s = compute_stats_single_threshold(tp, fp, dt_scores_sorted, self.recall_thresholds, num_gt)
                    recall[th_ind, cls_idx, maxDet_idx] = r
                    precision[th_ind, :, cls_idx, maxDet_idx] = p
                    # corresponding score thresholds for recall steps
                    scores[th_ind, :, cls_idx, maxDet_idx] = s

        return {
            'counts': [num_iou_th, num_recall_th, num_classes, num_max_detections],  # [4]
            'recall': recall,  # [num_iou_th, num_classes, num_max_detections]
            'precision': precision,  # [num_iou_th, num_recall_th, num_classes, num_max_detections]
            'scores': scores,  # [num_iou_th, num_recall_th, num_classes, num_max_detections]
        }
def compute_stats_single_threshold(tp: np.ndarray, fp: np.ndarray, dt_scores_sorted: np.ndarray,
                                   recall_thresholds: Sequence[float], num_gt: int) -> Tuple[
                                       float, np.ndarray, np.ndarray]:
    """
    Compute recall value, precision curve and score thresholds for a single
    IoU threshold
    Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py

    Args:
        tp (np.ndarray): cumsum over true positives [R], R is the number of detections
        fp (np.ndarray): cumsum over false positives [R], R is the number of detections
        dt_scores_sorted (np.ndarray): sorted (descending) scores [R], R is the number of detections
        recall_thresholds (Sequence[float]): recall thresholds which should be evaluated
        num_gt (int): number of ground truth bounding boxes (excluding boxes which are ignored)

    Returns:
        float: overall recall for given IoU value
        np.ndarray: precision values at defined recall values
            [RTH], where RTH is the number of recall thresholds
        np.ndarray: prediction scores corresponding to recall values
            [RTH], where RTH is the number of recall thresholds
    """
    num_recall_th = len(recall_thresholds)

    rc = tp / num_gt
    # np.spacing(1) is the smallest representable float epsilon; it guards
    # against division by zero when tp + fp is zero
    pr = tp / (fp + tp + np.spacing(1))

    # overall recall is the recall after all detections were considered;
    # without any predictions the recall is defined as 0
    recall = rc[-1] if len(tp) else 0

    # precision values nearest to the given recall thresholds
    precision = np.zeros((num_recall_th,))
    # prediction scores at the selected recall positions
    th_scores = np.zeros((num_recall_th,))
    # numpy is slow without cython optimization for accessing elements
    # use python array gets significant speed improvement
    pr = pr.tolist()
    precision = precision.tolist()

    # smooth precision curve into a monotonically decreasing (box) shape,
    # consistent with the Matlab / pycocotools implementation
    for i in range(len(tp) - 1, 0, -1):
        if pr[i] > pr[i - 1]:
            pr[i - 1] = pr[i]

    # get indices to nearest given recall threshold (nn interpolation!)
    inds = np.searchsorted(rc, recall_thresholds, side='left')
    for save_idx, array_index in enumerate(inds):
        # recall thresholds beyond the achieved recall have no matching
        # detection; stop and leave the remaining entries at 0.
        # (replaces the former bare `except: pass`, which silently swallowed
        # the IndexError raised on the first out-of-range index)
        if array_index >= len(pr):
            break
        precision[save_idx] = pr[array_index]
        th_scores[save_idx] = dt_scores_sorted[array_index]
    return recall, np.array(precision), np.array(th_scores)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import time
import numpy as np
from loguru import logger
from typing import Sequence, List, Dict, Optional, Union, Tuple
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from nndet.evaluator import DetectionMetric
from sklearn.metrics import roc_curve
from collections import defaultdict
class FROCMetric(DetectionMetric):
    """
    Free-response ROC (FROC) metric: sensitivity against the average number
    of false positives per image, evaluated at multiple IoU thresholds.
    """

    def __init__(self,
                 classes: Sequence[str],
                 iou_thresholds: Sequence[float] = (0.1, 0.5),
                 fpi_thresholds: Sequence[float] = (1/8, 1/4, 1/2, 1, 2, 4, 8),
                 per_class: bool = False, verbose: bool = True,
                 save_dir: Optional[Union[str, Path]] = None,
                 ):
        """
        Class to compute FROC

        Args:
            classes: name of each class
                (index needs to correspond to predicted class indices!)
            iou_thresholds: IoU thresholds for which FROC
                is evaluated
            fpi_thresholds: false positive per image
                thresholds (curve is interpolated at these values, score is
                the mean of the computed sens values at these positions)
            per_class: additional FROC curves are computed per class
            verbose: log time needed for evaluation
            save_dir: optional directory to save FROC plots to;
                no plots are created if None
        """
        self.classes = classes
        self.iou_thresholds = iou_thresholds
        self.fpi_thresholds = fpi_thresholds
        self.per_class = per_class
        self.verbose = verbose
        # normalize to Path once so plotting can rely on path semantics
        self.save_dir = Path(save_dir) if save_dir is not None else None

    def get_iou_thresholds(self) -> Sequence[float]:
        """
        Return IoU thresholds needed for this metric

        Returns:
            Sequence[float]: IoU thresholds [M], M is the number of thresholds
        """
        return self.iou_thresholds

    def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[
            Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute FROC score(s) and curve(s)

        Args:
            results_list: list with results per image (in list)
                per category (dict). Inner Dict contains multiple results
                obtained by :func:`box_matching_batch`.
                `dtMatches`: matched detections [T, D], where T = number of
                    thresholds, D = number of detections
                `gtMatches`: matched ground truth boxes [T, G], where
                    T = number of thresholds, G = number of ground truth
                `dtScores`: prediction scores [D] detection scores
                `gtIgnore`: ground truth boxes which should be ignored
                    [G] indicate whether ground truth should be ignored
                `dtIgnore`: detections which should be ignored [T, D],
                    indicate which detections should be ignored

        Returns:
            Dict[str, float]: FROC score per IoU (key: FROC_score_IoU_{iou:.2f})
            Dict[str, np.ndarray]: FROC curve computed at specified fpi
                thresholds per IoU; [R] R is the number of fpi thresholds
                (key: FROC_curve_IoU_{iou:.2f})
        """
        if self.verbose:
            logger.info('Start FROC metric computation...')
            tic = time.time()

        scores = {}
        curves = {}
        _score, _curve = self.compute_froc_mul_iou(results_list)
        scores.update(_score)
        curves.update(_curve)
        if self.verbose:
            toc = time.time()
            logger.info(f'FROC finished (t={(toc - tic):0.2f}s).')

        if self.per_class:
            _score, _curve = self.compute_froc_mul_iou_per_class(results_list)
            scores.update(_score)
            curves.update(_curve)
            if self.verbose:
                toc = time.time()
                logger.info(f'FROC per class finished (t={(toc - tic):0.2f}s).')

        if self.save_dir is not None:
            self.plot_froc_curves(curves)
        return scores, curves

    def compute_froc_mul_iou(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[
            Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute FROC curve for multiple IoU values (all classes pooled)

        Args:
            results_list: list with results per image (in list)
                per category (dict). Inner Dict contains multiple results
                obtained by :func:`box_matching_batch` (see :meth:`compute`
                for the exact key layout).

        Returns:
            Dict[str, float]: FROC score per IoU
            Dict[str, np.ndarray]: FROC curve computed at specified fpi
                thresholds per IoU; [R] R is the number of fpi thresholds
        """
        num_images = len(results_list)
        # flatten per-image, per-class results into one list
        results = [_r for r in results_list for _r in r.values()]
        if len(results) == 0:
            logger.warning("WARNING, no results found for froc computation")
            return ({"froc_score": 0},
                    {"froc_curve": np.zeros(len(self.fpi_thresholds))})

        # r['dtMatches'] [T, R], where R = sum(all detections)
        dt_matches = np.concatenate([r['dtMatches'] for r in results], axis=1)
        dt_ignores = np.concatenate([r['dtIgnore'] for r in results], axis=1)
        dt_scores = np.concatenate([r['dtScores'] for r in results])
        gt_ignore = np.concatenate([r['gtIgnore'] for r in results])
        self.check_number_of_iou(dt_matches, dt_ignores)

        num_gt = np.count_nonzero(gt_ignore == 0)  # number of ground truth boxes (non ignored)
        if num_gt == 0:
            logger.error("No ground truth found! Returning 0 in FROC.")
            return ({"froc_score": 0},
                    {"froc_curve": np.zeros(len(self.fpi_thresholds))})

        # keep shape in case of 1 threshold
        # NOTE(review): the boolean mask flattens the array, so this reshape
        # only succeeds when no detection is actually ignored -- verify if
        # ignore handling is ever enabled for this metric
        old_shape = dt_matches.shape
        dt_matches = dt_matches[np.logical_not(dt_ignores)].reshape(old_shape)

        curves = {}
        for iou_idx, iou_val in enumerate(self.iou_thresholds):
            # filter scores of ignored detections
            _scores = dt_scores[np.logical_not(dt_ignores[iou_idx])]
            assert len(_scores) == len(dt_matches[iou_idx])

            _fps, _sens, _th = (self.compute_froc_curve_one_iou(
                dt_matches[iou_idx], _scores, num_images, num_gt))
            # linearly interpolate sensitivity at the requested fpi values
            curves[iou_val] = np.interp(self.fpi_thresholds, _fps, _sens)

        # FROC score is the mean sensitivity over the fpi thresholds
        scores = {f"FROC_score_IoU_{key:.2f}": np.mean(c) for key, c in curves.items()}
        curves = {f"FROC_curve_IoU_{key:.2f}": c for key, c in curves.items()}
        curves["FROC_fpi_thresholds"] = self.fpi_thresholds
        return scores, curves

    @staticmethod
    def compute_froc_curve_one_iou(dt_matches: np.ndarray, dt_scores: np.ndarray,
                                   num_images: int, num_gt: int):
        """
        Compute FROC curve for a single IoU value

        Args:
            dt_matches (np.ndarray): binary array indicating which bounding
                boxes have a large enough overlap with gt;
                [R] where R is the number of predictions
            dt_scores (np.ndarray): prediction score for each bounding box;
                [R] where R is the number of predictions
            num_images (int): number of images
            num_gt (int): number of ground truth bounding boxes

        Returns:
            np.ndarray: false positives per image
            np.ndarray: sensitivity
            np.ndarray: thresholds
        """
        if dt_matches.size == 0:
            logger.warning("WARNING, no matches found.")
            return np.zeros((2,)), np.zeros((2,)), np.zeros((2,))

        num_matched = np.sum(dt_matches)
        num_unmatched = len(dt_matches) - num_matched

        # reuse sklearn's ROC computation; fpr/tpr are rescaled from rates
        # into absolute per-image / per-gt quantities below
        fpr, tpr, thresholds = roc_curve(dt_matches, dt_scores)
        if num_unmatched == 0:
            logger.warning("WARNING, no false positives found")
            fps = np.zeros(len(fpr))
        else:
            fps = (fpr * num_unmatched) / num_images
        sens = (tpr * num_matched) / num_gt
        return fps, sens, thresholds

    def compute_froc_mul_iou_per_class(
            self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[
            Dict[str, float], Dict[str, np.ndarray]]:
        """
        Compute FROC curves separately for each class

        Args:
            results_list: list with results per image (in list)
                per category (dict). Inner Dict contains multiple results
                obtained by :func:`box_matching_batch` (see :meth:`compute`
                for the exact key layout).

        Returns:
            Dict[str, float]: FROC score per class and IoU
            Dict[str, np.ndarray]: FROC curve per class and IoU;
                [R] R is the number of fpi thresholds
        """
        froc_scores_cls = {}
        froc_curves_cls = {}
        for cls_idx, cls_str in enumerate(self.classes):
            # filter the current class from the results and remap it to key 0
            # so :meth:`compute_froc_mul_iou` can be reused unchanged
            # (fixed: the `cls_idx in r` check was duplicated before)
            results_by_cls = [{0: r[cls_idx]} for r in results_list if cls_idx in r]
            if results_by_cls:
                cls_scores, cls_curves = self.compute_froc_mul_iou(results_by_cls)
                froc_scores_cls.update({f"{cls_str}_{key}": item for key, item in cls_scores.items()})
                froc_curves_cls.update({f"{cls_str}_{key}": item for key, item in cls_curves.items()})
        return froc_scores_cls, froc_curves_cls

    def plot_froc_curves(self, curves: Dict[str, Sequence[float]]) -> None:
        """
        Plot froc curves and save them to :attr:`save_dir`

        Args:
            curves: dict with froc curves (as obtained by :meth:`compute`)
                FROC_curve_IoU_{iou:.2f} for the class independent FROC
                {cls_name}_FROC_curve_IoU_{iou:.2f} for class specific frocs
        """
        # plot class independent froc curves (one line per IoU)
        selection = select_froc_curves(curves)
        fig, ax = get_froc_ax(self.fpi_thresholds)
        for _, froc, iou in zip(*selection):
            ax.plot(self.fpi_thresholds, froc, 'o-', label=f"IoU:{iou:.2f}")
        ax.set_title("FROC")
        ax.legend(loc='lower right')
        fig.savefig(self.save_dir / "FROC.png")
        plt.close(fig)

        # plot class specific frocs: one figure per IoU with a line per class
        selection = select_froc_curves_cls(curves)
        reordered = defaultdict(list)
        for class_name, (names, frocs, ious) in selection.items():
            for froc, iou in zip(frocs, ious):
                reordered[iou].append((class_name, froc))

        for iou, frocs in reordered.items():
            fig, ax = get_froc_ax(self.fpi_thresholds)
            for class_name, froc in frocs:
                ax.plot(self.fpi_thresholds, froc, 'o-', label=f"{class_name}")
            title = f"FROC_cls_IoU_{iou:.2f}"
            ax.set_title(title)
            ax.legend(loc='lower right')
            fig.savefig(self.save_dir / f"{title.replace('.', '_')}.png")
            plt.close(fig)
def get_froc_ax(fpi_values: Optional[Sequence[float]] = None) -> Tuple[plt.Figure, plt.Axes]:
    """
    Build a figure/axes pair preconfigured for plotting froc curves

    Args:
        fpi_values: x positions (false positives per image) used to set the
            axis limits and ticks; if None the x range is left to matplotlib

    Returns:
        plt.Figure: created figure
        plt.Axes: axes with log2 x-axis, y range [0, 1] and grid enabled
    """
    fig, axes = plt.subplots()
    # fpi values are typically powers of two -> log2 axis
    axes.set_xscale("log", base=2)
    if fpi_values is not None:
        axes.set_xlim(min(fpi_values), max(fpi_values))
        axes.set_xticks(fpi_values)
    axes.set_ylim(0, 1)
    axes.set_xlabel('Avg number of false positives per scan')
    axes.set_ylabel('Sensitivity')
    axes.grid(True)
    # show plain decimal tick labels instead of 2^n notation
    axes.xaxis.set_major_formatter(FuncFormatter(lambda val, _pos: f'{val:.3f}'))
    return fig, axes
def select_froc_curves(curves: Dict[str, np.ndarray], prefix: Optional[str] = None) -> \
        Tuple[List[str], List[np.ndarray], List[float]]:
    """
    Select froc curves from a result dict

    Args:
        curves: dict to select frocs from. Keys need to follow the
            ``{prefix}FROC_curve_IoU_{iou:.2f}`` pattern; keys ending in
            ``_thresholds`` are skipped
        prefix: only keys starting with ``{prefix}FROC_`` are selected;
            None is equivalent to an empty prefix (class independent curves)

    Returns:
        List[str]: selected keys
        List[np.ndarray]: froc curve for each selected key
        List[float]: IoU value parsed from the end of each selected key
    """
    if prefix is None:
        prefix = ""
    froc_keys = [str(c) for c in curves.keys()
                 if str(c).startswith(f"{prefix}FROC_") and
                 not str(c).endswith("_thresholds")]
    frocs = [curves[c] for c in froc_keys]
    # the IoU value is the last underscore-separated token of the key
    ious = [float(c.rsplit('_', 1)[1]) for c in froc_keys]
    return froc_keys, frocs, ious
def select_froc_curves_cls(curves: Dict[str, np.ndarray]) -> \
        Dict[str, Tuple[List[str], List[np.ndarray], List[float]]]:
    """
    Select class specific froc curves

    Args:
        curves: dict to select frocs from. Class specific frocs need to follow
            {cls_name}_FROC_curve_IoU_{key:.2f} pattern

    Returns:
        Dict[str, Tuple[List[str], List[np.ndarray], List[float]]]:
            mapping from class name to the tuple produced by
            :func:`select_froc_curves` for that class
    """
    # split on the "_FROC_" marker so class names which themselves contain
    # underscores are extracted correctly (previously only the first
    # underscore-separated token was used as the class name); class
    # independent keys start with "FROC_" and contain no "_FROC_" marker
    all_classes = sorted({str(c).split('_FROC_', 1)[0] for c in curves.keys()
                          if '_FROC_' in str(c) and
                          not str(c).startswith("FROC_") and
                          not str(c).endswith("_thresholds")})
    output = {}
    for cls_name in all_classes:
        output[cls_name] = select_froc_curves(curves, prefix=f"{cls_name}_")
    return output
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import time
import numpy as np
from pathlib import Path
from loguru import logger
from typing import Sequence, List, Dict, Any, Tuple
import matplotlib.pyplot as plt
from nndet.evaluator import DetectionMetric
class PredictionHistogram(DetectionMetric):
    def __init__(self,
                 classes: Sequence[str], save_dir: Path,
                 iou_thresholds: Sequence[float] = (0.1, 0.5),
                 bins: int = 50):
        """
        Class to compute prediction histograms. (Note: this class does not
        provide any scalar metrics)

        Args:
            classes: name of each class (index needs to correspond to predicted class indices!)
            save_dir: directory where histograms are saved to
            iou_thresholds: IoU thresholds for which histograms are computed
            bins: number of bins of histogram
        """
        self.classes = classes
        self.save_dir = save_dir
        self.iou_thresholds = iou_thresholds
        self.bins = bins

    def get_iou_thresholds(self) -> Sequence[float]:
        """
        Return IoU thresholds needed for this metric

        Returns:
            Sequence[float]: IoU thresholds [M], M is the number of thresholds
        """
        return self.iou_thresholds

    def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[
            Dict[str, float], Dict[str, Dict[str, Any]]]:
        """
        Plot class independent and per class histograms. For more info see
        :meth:`plot_hist`

        Args:
            results_list: matching results over the whole dataset (see
                :meth:`plot_hist` for the exact structure)

        Returns:
            Dict: empty
            Dict: empty
        """
        self.plot_hist(results_list=results_list)
        for cls_idx, cls_str in enumerate(self.classes):
            # filter the current class from the results and remap it to key 0
            # (fixed: the `cls_idx in r` check was duplicated before)
            results_by_cls = [{0: r[cls_idx]} for r in results_list if cls_idx in r]
            self.plot_hist(results_by_cls, title_prefix=f"cl_{cls_str}_")
        return {}, {}

    def plot_hist(self, results_list: List[Dict[int, Dict[str, np.ndarray]]],
                  title_prefix: str = "") -> Tuple[
            Dict[str, float], Dict[str, Dict[str, Any]]]:
        """
        Plot prediction histograms for all configured IoU values

        Args:
            results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list)
                per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
                `dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections
                `gtMatches`: matched ground truth boxes [T, G], where T = number of thresholds,
                    G = number of ground truth
                `dtScores`: prediction scores [D] detection scores
                `gtIgnore`: ground truth boxes which should be ignored [G] indicate whether ground truth
                    should be ignored
                `dtIgnore`: detections which should be ignored [T, D], indicate which detections should be ignored
            title_prefix: prefix for title of histogram plot

        Returns:
            Dict: empty
            Dict: empty
        """
        num_images = len(results_list)
        # flatten per-image, per-class results into one list
        results = [_r for r in results_list for _r in r.values()]
        if len(results) == 0:
            # fixed: the message previously referred to "froc computation"
            logger.warning("WARNING, no results found for histogram computation")
            return {}, {}

        # r['dtMatches'] [T, R], where R = sum(all detections)
        dt_matches = np.concatenate([r['dtMatches'] for r in results], axis=1)
        dt_ignores = np.concatenate([r['dtIgnore'] for r in results], axis=1)
        dt_scores = np.concatenate([r['dtScores'] for r in results])
        gt_ignore = np.concatenate([r['gtIgnore'] for r in results])
        self.check_number_of_iou(dt_matches, dt_ignores)

        num_gt = np.count_nonzero(gt_ignore == 0)  # number of ground truth boxes (non ignored)
        if num_gt == 0:
            logger.error("No ground truth found! Returning nothing.")
            return {}, {}

        for iou_idx, iou_val in enumerate(self.iou_thresholds):
            # filter scores of ignored detections
            _scores = dt_scores[np.logical_not(dt_ignores[iou_idx])]
            assert len(_scores) == len(dt_matches[iou_idx])
            _ = self.compute_histogram_one_iou(
                dt_matches[iou_idx], _scores, num_images, num_gt, iou_val, title_prefix)
        return {}, {}

    def compute_histogram_one_iou(self, dt_matches: np.ndarray, dt_scores: np.ndarray,
                                  num_images: int, num_gt: int, iou: float,
                                  title_prefix: str):
        """
        Plot prediction histogram for a single IoU value

        Args:
            dt_matches (np.ndarray): binary array indicating which bounding
                boxes have a large enough overlap with gt;
                [R] where R is the number of predictions
            dt_scores (np.ndarray): prediction score for each bounding box;
                [R] where R is the number of predictions
            num_images (int): number of images (unused; kept for a consistent
                interface with the FROC computation)
            num_gt (int): number of ground truth bounding boxes
            iou: IoU value which is currently evaluated
            title_prefix: prefix for title of histogram plot
        """
        true_positives = np.sum(dt_matches)
        false_positives = np.sum(dt_matches == 0)
        false_negatives = num_gt - true_positives  # unmatched ground truth

        # append false negatives as matched entries with score 0 so they show
        # up at the left edge of the true positive histogram
        _dt_matches = np.concatenate([dt_matches, [1] * int(false_negatives)])
        _dt_scores = np.concatenate([dt_scores, [0] * int(false_negatives)])

        plt.figure()
        plt.yscale('log')
        if 0 in dt_matches:
            plt.hist(_dt_scores[_dt_matches == 0], bins=self.bins, range=(0., 1.),
                     alpha=0.3, color='g', label='false pos.')
        if 1 in dt_matches:
            plt.hist(_dt_scores[_dt_matches == 1], bins=self.bins, range=(0., 1.),
                     alpha=0.3, color='b', label='true pos. (false neg. @ score=0)')
        plt.legend()
        title = title_prefix + (f"tp:{true_positives} fp:{false_positives} "
                                f"fn:{false_negatives} pos:{true_positives+false_negatives}")
        plt.title(title)
        plt.xlabel('confidence score')
        plt.ylabel('log n')
        if self.save_dir is not None:
            save_path = self.save_dir / (f"{title_prefix}pred_hist_IoU@{iou}".replace(".", "_") + ".png")
            logger.info(f"Saving {save_path}")
            plt.savefig(save_path)
        plt.close()
        return None
import numpy as np
from typing import Callable, Sequence, List, Dict
__all__ = ["matching_batch"]
def matching_batch(
    iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
    iou_thresholds: Sequence[float], pred_boxes: Sequence[np.ndarray],
    pred_classes: Sequence[np.ndarray], pred_scores: Sequence[np.ndarray],
    gt_boxes: Sequence[np.ndarray], gt_classes: Sequence[np.ndarray],
    gt_ignore: Sequence[Sequence[bool]], max_detections: int = 100,
) -> List[Dict[int, Dict[str, np.ndarray]]]:
    """
    Match predicted boxes of a batch against ground truth, independently for
    each category that occurs in an image

    Args:
        iou_fn: computes pairwise overlap between two sets of boxes
        iou_thresholds: IoU thresholds which should be evaluated
        pred_boxes: predicted boxes from single batch; List[[D, dim * 2]],
            D number of predictions
        pred_classes: predicted classes from a single batch; List[[D]],
            D number of predictions
        pred_scores: predicted score for each bounding box; List[[D]],
            D number of predictions
        gt_boxes: ground truth boxes; List[[G, dim * 2]], G number of
            ground truth
        gt_classes: ground truth classes; List[[G]], G number of ground truth
        gt_ignore: specifies which ground truth boxes are not counted as
            true positives (detections matching these boxes are not counted
            as false positives either); List[[G]], G number of ground truth
        max_detections: maximum number of detections which should be evaluated

    Returns:
        List[Dict[int, Dict[str, np.ndarray]]]: per image (list) and per
            category (dict key) the matching arrays produced by the
            single-image helpers (`dtMatches`, `gtMatches`, `dtScores`,
            `gtIgnore`, `dtIgnore`)
    """
    batch_results = []
    per_image = zip(pred_boxes, pred_classes, pred_scores,
                    gt_boxes, gt_classes, gt_ignore)
    for pboxes, pclasses, pscores, gboxes, gclasses, gignore in per_image:
        image_result = {}
        # evaluate every class that appears in either predictions or gt
        for cls in np.union1d(pclasses, gclasses):
            cls_pred = pclasses == cls
            cls_gt = gclasses == cls
            if not np.any(cls_gt):
                # predictions only: all of them are potential false positives
                image_result[cls] = _matching_no_gt(
                    iou_thresholds=iou_thresholds,
                    pred_scores=pscores[cls_pred],
                    max_detections=max_detections)
            elif not np.any(cls_pred):
                # ground truth only: all of them are false negatives
                image_result[cls] = _matching_no_pred(
                    iou_thresholds=iou_thresholds,
                    gt_ignore=gignore[cls_gt],
                )
            else:
                # at least one prediction and one ground truth box
                image_result[cls] = _matching_single_image_single_class(
                    iou_fn=iou_fn,
                    pred_boxes=pboxes[cls_pred],
                    pred_scores=pscores[cls_pred],
                    gt_boxes=gboxes[cls_gt],
                    gt_ignore=gignore[cls_gt],
                    max_detections=max_detections,
                    iou_thresholds=iou_thresholds,
                )
        batch_results.append(image_result)
    return batch_results
def _matching_no_gt(
iou_thresholds: Sequence[float],
pred_scores: np.ndarray,
max_detections: int,
):
"""
Matching result with not ground truth in image
Args:
iou_thresholds: defined which IoU thresholds should be evaluated
dt_scores: predicted scores
max_detections: maximum number of allowed detections per image.
This functions uses this parameter to stay consistent with
the actual matching function which needs this limit.
Returns:
dict: computed matching
`dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number
of thresholds, G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored
[G] indicate whether ground truth should be ignored
`dtIgnore`: detections which should be ignored [T, D],
indicate which detections should be ignored
"""
dt_ind = np.argsort(-pred_scores, kind='mergesort')
dt_ind = dt_ind[:max_detections]
dt_scores = pred_scores[dt_ind]
num_preds = len(dt_scores)
gt_match = np.array([[]] * len(iou_thresholds))
dt_match = np.zeros((len(iou_thresholds), num_preds))
dt_ignore = np.zeros((len(iou_thresholds), num_preds))
return {
'dtMatches': dt_match, # [T, D], where T = number of thresholds, D = number of detections
'gtMatches': gt_match, # [T, G], where T = number of thresholds, G = number of ground truth
'dtScores': dt_scores, # [D] detection scores
'gtIgnore': np.array([]).reshape(-1), # [G] indicate whether ground truth should be ignored
'dtIgnore': dt_ignore, # [T, D], indicate which detections should be ignored
}
def _matching_no_pred(
iou_thresholds: Sequence[float],
gt_ignore: np.ndarray,
):
"""
Matching result with no predictions
Args:
iou_thresholds: defined which IoU thresholds should be evaluated
gt_ignore: specified if which ground truth boxes are not counted as
true positives (detections which match theses boxes are not
counted as false positives either); [G], G number of ground truth
Returns:
dict: computed matching
`dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
`gtMatches`: matched ground truth boxes [T, G], where T = number
of thresholds, G = number of ground truth
`dtScores`: prediction scores [D] detection scores
`gtIgnore`: ground truth boxes which should be ignored
[G] indicate whether ground truth should be ignored
`dtIgnore`: detections which should be ignored [T, D],
indicate which detections should be ignored
"""
dt_scores = np.array([])
dt_match = np.array([[]] * len(iou_thresholds))
dt_ignore = np.array([[]] * len(iou_thresholds))
gt_match = np.zeros((len(iou_thresholds), len(gt_ignore)))
return {
'dtMatches': dt_match, # [T, D], where T = number of thresholds, D = number of detections
'gtMatches': gt_match, # [T, G], where T = number of thresholds, G = number of ground truth
'dtScores': dt_scores, # [D] detection scores
'gtIgnore': gt_ignore.reshape(-1), # [G] indicate whether ground truth should be ignored
'dtIgnore': dt_ignore, # [T, D], indicate which detections should be ignored
}
def _matching_single_image_single_class(
    iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
    pred_boxes: np.ndarray,
    pred_scores: np.ndarray,
    gt_boxes: np.ndarray,
    gt_ignore: np.ndarray,
    max_detections: int,
    iou_thresholds: Sequence[float],
) -> Dict[str, np.ndarray]:
    """
    Greedily match the predictions of one image and one class to ground truth.
    Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py

    Args:
        iou_fn: computes pairwise overlap between two sets of boxes
        iou_thresholds: IoU thresholds which should be evaluated
        pred_boxes: predicted boxes from single batch; [D, dim * 2], D number
            of predictions
        pred_scores: predicted score for each bounding box; [D], D number of
            predictions
        gt_boxes: ground truth boxes; [G, dim * 2], G number of ground truth
        gt_ignore: specifies which ground truth boxes are not counted as
            true positives (detections matching these boxes are not
            counted as false positives either); [G], G number of ground truth
        max_detections: maximum number of detections which should be evaluated

    Returns:
        dict: computed matching
            `dtMatches`: matched detections [T, D], where T = number of
                thresholds, D = number of detections
            `gtMatches`: matched ground truth boxes [T, G], where T = number
                of thresholds, G = number of ground truth
            `dtScores`: prediction scores [D] detection scores
            `gtIgnore`: ground truth boxes which should be ignored
                [G] indicate whether ground truth should be ignored
            `dtIgnore`: detections which should be ignored [T, D],
                indicate which detections should be ignored
    """
    # filter for max_detections highest scoring predictions to speed up computation
    dt_ind = np.argsort(-pred_scores, kind='mergesort')
    dt_ind = dt_ind[:max_detections]

    pred_boxes = pred_boxes[dt_ind]
    pred_scores = pred_scores[dt_ind]

    # sort ignored ground truth to last positions so regular gt is
    # preferred during the greedy matching below
    gt_ind = np.argsort(gt_ignore, kind='mergesort')
    gt_boxes = gt_boxes[gt_ind]
    gt_ignore = gt_ignore[gt_ind]

    # ious between sorted(!) predictions and ground truth
    ious = iou_fn(pred_boxes, gt_boxes)

    num_preds, num_gts = ious.shape[0], ious.shape[1]
    gt_match = np.zeros((len(iou_thresholds), num_gts))
    dt_match = np.zeros((len(iou_thresholds), num_preds))
    dt_ignore = np.zeros((len(iou_thresholds), num_preds))

    for tind, t in enumerate(iou_thresholds):
        for dind, _d in enumerate(pred_boxes):  # iterate detections starting from highest scoring one
            # information about best match so far (m=-1 -> unmatched)
            # start slightly below t so any overlap >= t qualifies
            iou = min([t, 1-1e-10])
            m = -1

            for gind, _g in enumerate(gt_boxes):  # iterate ground truth
                # if this gt already matched, continue
                if gt_match[tind, gind] > 0:
                    continue

                # if dt matched to reg gt, and on ignore gt, stop
                # (ignored gt are sorted to the end, so no better match can follow)
                if m > -1 and gt_ignore[m] == 0 and gt_ignore[gind] == 1:
                    break

                # continue to next gt unless better match made
                if ious[dind, gind] < iou:
                    continue

                # if match successful and best so far, store appropriately
                iou = ious[dind, gind]
                m = gind

            # if match made, store id of match for both dt and gt
            if m == -1:
                continue
            else:
                # a detection matched to an ignored gt is itself ignored
                dt_ignore[tind, dind] = int(gt_ignore[m])
                dt_match[tind, dind] = 1
                gt_match[tind, m] = 1

    # store results for given image and category
    return {
        'dtMatches': dt_match,  # [T, D], where T = number of thresholds, D = number of detections
        'gtMatches': gt_match,  # [T, G], where T = number of thresholds, G = number of ground truth
        'dtScores': pred_scores,  # [D] detection scores
        'gtIgnore': gt_ignore.reshape(-1),  # [G] indicate whether ground truth should be ignored
        'dtIgnore': dt_ignore,  # [T, D], indicate which detections should be ignored
    }
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from os import PathLike
from pathlib import Path
from typing import Dict, Sequence, Optional, Tuple
import numpy as np
from loguru import logger
from nndet.io.load import load_pickle, save_json, save_pickle
from nndet.evaluator.det import BoxEvaluator
from nndet.evaluator.case import CaseEvaluator
from nndet.evaluator.seg import PerCaseSegmentationEvaluator
def save_metric_output(scores, curves, base_dir, name):
    """
    Persist metric results: scalar scores as human readable json and the
    full scores/curves pair as pickle.
    """
    readable = {str(key): str(value) for key, value in scores.items()}
    save_json(readable, base_dir / f"{name}.json")
    save_pickle({"scores": scores, "curves": curves}, base_dir / f"{name}.pkl")
def evaluate_box_dir(
    pred_dir: PathLike,
    gt_dir: PathLike,
    classes: Sequence[str],
    save_dir: Optional[Path] = None,
) -> Tuple[Dict, Dict]:
    """
    Run box evaluation inside a directory

    Args:
        pred_dir: path to dir with predictions ({case_id}_boxes.pkl files)
        gt_dir: path to dir with ground truth data ({case_id}_boxes_gt.npz files)
        classes: classes present in dataset
        save_dir: optional path to save plots

    Returns:
        Dict[str, float]: dictionary with scalar values for evaluation
        Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs

    See Also:
        :class:`nndet.evaluator.registry.BoxEvaluator`
    """
    pred_dir = Path(pred_dir)
    gt_dir = Path(gt_dir)
    if save_dir is not None:
        # fixed: also accept plain strings (consistent with pred_dir/gt_dir)
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    # cases are discovered from the prediction files
    case_ids = [p.stem.rsplit('_boxes', 1)[0] for p in pred_dir.iterdir()
                if p.is_file() and p.stem.endswith("_boxes")]
    logger.info(f"Found {len(case_ids)} for box evaluation in {pred_dir}")

    evaluator = BoxEvaluator.create(classes=classes,
                                    fast=False,
                                    verbose=False,
                                    save_dir=save_dir,
                                    )
    for case_id in case_ids:
        gt = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
        pred = load_pickle(pred_dir / f"{case_id}_boxes.pkl")
        evaluator.run_online_evaluation(
            pred_boxes=[pred["pred_boxes"]], pred_classes=[pred["pred_labels"]],
            pred_scores=[pred["pred_scores"]], gt_boxes=[gt["boxes"]],
            gt_classes=[gt["classes"]], gt_ignore=None,
        )
    return evaluator.finish_online_evaluation()
def evaluate_case_dir(
        pred_dir: PathLike,
        gt_dir: PathLike,
        classes: Sequence[str],
        target_class: Optional[int] = None,
) -> Tuple[Dict, Dict]:
    """
    Run evaluation of case results inside a directory

    Args:
        pred_dir: path to dir with predictions
        gt_dir: path to dir with ground truth data
        classes: classes present in dataset
        target_class: in case of multiple classes, specify a target class
            to evaluate in a target class vs rest setting

    Returns:
        Dict[str, float]: dictionary with scalar values for evaluation
        Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs

    See Also:
        :class:`nndet.evaluator.registry.CaseEvaluator`
    """
    pred_dir = Path(pred_dir)
    gt_dir = Path(gt_dir)

    # collect case ids from prediction files ending in "_boxes"
    case_ids = []
    for candidate in pred_dir.iterdir():
        if candidate.is_file() and candidate.stem.endswith("_boxes"):
            case_ids.append(candidate.stem.rsplit('_boxes', 1)[0])
    logger.info(f"Found {len(case_ids)} for case evaluation in {pred_dir}")

    evaluator = CaseEvaluator.create(classes=classes,
                                     target_class=target_class,
                                     )
    for cid in case_ids:
        gt = np.load(str(gt_dir / f"{cid}_boxes_gt.npz"), allow_pickle=True)
        pred = load_pickle(pred_dir / f"{cid}_boxes.pkl")
        evaluator.run_online_evaluation(
            pred_classes=[pred["pred_labels"]],
            pred_scores=[pred["pred_scores"]],
            gt_classes=[gt["classes"]]
        )
    return evaluator.finish_online_evaluation()
def evaluate_seg_dir(
        pred_dir: PathLike,
        gt_dir: PathLike,
        classes: Sequence[str],
) -> Tuple[Dict, None]:
    """
    Compute dice metric across a directory

    Args:
        pred_dir: path to dir with predictions
        gt_dir: path to dir with ground truth data
        classes: classes present in dataset

    Returns:
        Dict[str, float]: dictionary with scalar values for evaluation
        None: placeholder (no curve data for segmentation)

    See Also:
        :class:`nndet.evaluator.registry.PerCaseSegmentationEvaluator`
    """
    pred_dir = Path(pred_dir)
    gt_dir = Path(gt_dir)

    # case id is the file stem without the trailing "_seg" suffix
    case_ids = [f.stem.rsplit('_seg', 1)[0]
                for f in pred_dir.iterdir()
                if f.is_file() and f.stem.endswith("_seg")]
    logger.info(f"Found {len(case_ids)} for seg evaluation in {pred_dir}")

    evaluator = PerCaseSegmentationEvaluator.create(classes=classes)
    for cid in case_ids:
        target = np.load(str(gt_dir / f"{cid}_seg_gt.npz"), allow_pickle=True)["seg"]  # 1, dims
        prediction = load_pickle(pred_dir / f"{cid}_seg.pkl")
        evaluator.run_online_evaluation(
            seg=prediction[None],
            target=target,
        )
    return evaluator.finish_online_evaluation()
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from typing import Dict, Sequence, Tuple
from collections import defaultdict
from nndet.evaluator import AbstractEvaluator
__all__ = ["SegmentationEvaluator"]
class SegmentationEvaluator(AbstractEvaluator):
    def __init__(self,
                 per_class: bool = True,
                 *args,
                 **kwargs,
                 ):
        """
        Compute dice score during training

        Args:
            per_class: if True, additionally report one global dice value
                per foreground class in :meth:`finish_online_evaluation`
        """
        self.per_class = per_class
        # accumulates per-batch arrays under keys "tp"/"fp"/"fn" (appended)
        # and "fg_dice" (assigned, see note in run_online_evaluation)
        self.results_list = defaultdict(list)

    def reset(self):
        """
        Reset internal state for new epoch
        """
        self.results_list = defaultdict(list)

    def run_online_evaluation(self,
                              seg_probs: np.ndarray,
                              target: np.ndarray,
                              ) -> Dict:
        """
        Run evaluation of one batch and save internal results for later

        Args:
            seg_probs: output probabilities of network [N, C, dims], where N
                is the batch size, C is the number of classes, dims are
                spatial dimensions
            target: ground truth segmentation [N, dims], where N is the batch
                size and dims are spatial dimensions

        Returns:
            Dict: empty dict
        """
        num_classes = seg_probs.shape[1]
        # hard prediction via argmax over channels; flatten spatial dims
        output_seg = np.argmax(seg_probs, axis=1).reshape((seg_probs.shape[0], -1))
        target = target.reshape((target.shape[0], -1))

        # per-case tp/fp/fn for every foreground class (background c=0 skipped)
        tp_hard = np.zeros((target.shape[0], num_classes - 1))
        fp_hard = np.zeros((target.shape[0], num_classes - 1))
        fn_hard = np.zeros((target.shape[0], num_classes - 1))
        for c in range(1, num_classes):
            tp_hard[:, c - 1] = ((output_seg == c).astype(np.float32) * (target == c).astype(np.float32)).sum(axis=1)
            fp_hard[:, c - 1] = ((output_seg == c).astype(np.float32) * (target != c).astype(np.float32)).sum(axis=1)
            fn_hard[:, c - 1] = ((output_seg != c).astype(np.float32) * (target == c).astype(np.float32)).sum(axis=1)
        # aggregate counts over the batch dimension -> one value per fg class
        tp_hard = tp_hard.sum(axis=0)
        fp_hard = fp_hard.sum(axis=0)
        fn_hard = fn_hard.sum(axis=0)

        # NOTE(review): "fg_dice" is assigned (not appended), so only the
        # most recent batch's foreground dice is kept — confirm intended.
        self.results_list["fg_dice"] = list(
            (2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8))
        self.results_list["tp"].append(tp_hard)
        self.results_list["fp"].append(fp_hard)
        self.results_list["fn"].append(fn_hard)
        return {}

    def finish_online_evaluation(self) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Summarize results from batches and compute global dice and global
        dice per class

        Returns:
            Dict: results
                `{cls_idx}_seg_dice`: global dice per class
                `seg_dice`: global dice over all classes
            None: placeholder for curve data (none produced here)
        """
        results = {}
        if self.results_list:
            # total counts over the whole epoch, one entry per fg class
            tp = np.sum(self.results_list["tp"], 0)
            fp = np.sum(self.results_list["fp"], 0)
            fn = np.sum(self.results_list["fn"], 0)

            # classes that never occurred (tp=fp=fn=0) yield NaN and are
            # dropped. NOTE(review): dropping entries shifts the enumerate
            # index below, so cls_idx may not match the original class index
            # when some classes are filtered out — confirm intended.
            global_dc_per_class = [
                i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(tp, fp, fn)] if not np.isnan(i)]
            if self.per_class:
                for cls_idx, dc in enumerate(global_dc_per_class):
                    results[f"{cls_idx}_seg_dice"] = dc
            results["seg_dice"] = np.mean(global_dc_per_class)
        return results, None

    @classmethod
    def create(cls,
               per_class: bool = False,
               ):
        """
        Factory method; note the default here (False) differs from
        the constructor default (True).
        """
        return cls(per_class=per_class)
class PerCaseSegmentationEvaluator(AbstractEvaluator):
    def __init__(self,
                 classes: Sequence[str],
                 *args,
                 **kwargs,
                 ):
        """
        Compute dice score per case and average results over dataset

        Args:
            classes: all classes of the dataset; only foreground classes
                1 .. len(classes)-1 are scored (index 0 treated as background)
        """
        self.classes = classes
        # one entry per processed batch: dice array [N, num fg classes]
        self.results = []

    def reset(self):
        """
        Reset internal state for new epoch
        """
        self.results = []

    def run_online_evaluation(self,
                              seg: np.ndarray,
                              target: np.ndarray,
                              ) -> Dict:
        """
        Run evaluation of one batch and save internal results for later

        Args:
            seg: output segmentation [N, dims]
            target: ground truth segmentation [N, dims], where N is the batch
                size and dims are spatial dimensions

        Returns:
            Dict: empty dict
        """
        assert len(seg) == len(target)
        num_classes = len(self.classes)
        output_seg = seg.reshape((seg.shape[0], -1))  # N, X
        target = target.reshape((target.shape[0], -1))  # N, X

        tp_hard = np.zeros((target.shape[0], num_classes - 1))  # N, FG
        fp_hard = np.zeros((target.shape[0], num_classes - 1))  # N, FG
        fn_hard = np.zeros((target.shape[0], num_classes - 1))  # N, FG
        fg_present = np.zeros((target.shape[0], num_classes - 1))  # N, FG
        for c in range(1, num_classes):
            tp_hard[:, c - 1] = ((output_seg == c).astype(np.float32) * (target == c).astype(np.float32)).sum(axis=1)
            fp_hard[:, c - 1] = ((output_seg == c).astype(np.float32) * (target != c).astype(np.float32)).sum(axis=1)
            fn_hard[:, c - 1] = ((output_seg != c).astype(np.float32) * (target == c).astype(np.float32)).sum(axis=1)
            fg_present[:, c - 1] = (target == c).any(axis=1).astype(np.int32)

        # cases where a class has no foreground are marked NaN so they can be
        # excluded from the averages in finish_online_evaluation
        dice = np.where(fg_present, 2. * tp_hard / (2 * tp_hard + fp_hard + fn_hard), np.nan)  # N, FG
        self.results.append(dice)
        return {}

    def finish_online_evaluation(self) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """
        Summarize results from batches: mean dice per foreground class and
        mean dice over all classes/cases

        Returns:
            Dict: results
                `dice_cls_{cls_idx}`: mean dice per class
                `dice`: mean dice over all classes
            None: placeholder for curve data (none produced here)
        """
        dice_full = np.concatenate(self.results, axis=0)  # num cases, FG
        # fix: original called mean(axies=0) which raises a TypeError;
        # nanmean additionally skips NaN entries of cases without foreground
        # of a class (they are deliberately marked NaN above)
        dice_per_class = np.nanmean(dice_full, axis=0)  # FG
        dice = np.nanmean(dice_full)  # scalar

        results = {}
        for cls_idx, value in enumerate(dice_per_class):
            results[f"dice_cls_{cls_idx}"] = float(value)
        results["dice"] = float(dice)
        return results, None

    @classmethod
    def create(cls,
               classes: Sequence[str],
               ):
        """
        Factory method mirroring the other evaluators
        """
        return cls(classes=classes)
from nndet.inference.ensembler import BaseEnsemblerType, BaseEnsembler, BoxEnsembler, SegmentationEnsembler
from nndet.inference.predictor import PredictorType, Predictor
from nndet.inference.sweeper import SweeperType, Sweeper, BoxSweeper
from nndet.inference.restore import restore_detection, restore_fmap
from nndet.inference.detection.wbc import batched_wbc, wbc
from nndet.inference.detection.model import batched_nms_model
from nndet.inference.detection.ensemble import batched_wbc_ensemble, batched_nms_ensemble, \
wbc_nms_no_label_ensemble
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple
from torch import Tensor
from nndet.core.boxes import batched_nms, nms
from nndet.inference.detection import batched_wbc
def batched_nms_ensemble(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Class-aware nms adapted to the ensembler call signature.

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels (used as nms groups)
        weights: weight per box (ignored here)
        iou_thresh: IoU threshold for nms
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: boxes
        Tensor: scores
        Tensor: labels
    """
    keep = batched_nms(boxes=boxes, scores=scores, idxs=labels, iou_threshold=iou_thresh)
    return boxes[keep], scores[keep], labels[keep]
def batched_wbc_ensemble(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        n_exp_preds: Tensor,
        score_thresh: float,
        *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Weighted box clustering adapted to the ensembler call signature.

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels
        weights: weight per box
        iou_thresh: IoU threshold used for clustering
        n_exp_preds: number of expected predictions per box
        score_thresh: minimum score
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: boxes
        Tensor: scores
        Tensor: labels
    """
    return batched_wbc(
        boxes, scores, labels, weights=weights,
        n_exp_preds=n_exp_preds,
        iou_thresh=iou_thresh,
        score_thresh=score_thresh,
    )
def wbc_nms_no_label_ensemble(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        n_exp_preds: Tensor,
        score_thresh: float,
        *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Weighted box clustering followed by class-agnostic nms.
    This results in a single prediction per position regardless of the class.

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels
        weights: weight per box
        iou_thresh: IoU threshold for clustering and nms
        n_exp_preds: number of expected predictions per box
        score_thresh: minimum score
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: boxes
        Tensor: scores
        Tensor: labels
    """
    clustered_boxes, clustered_scores, clustered_labels = batched_wbc(
        boxes, scores, labels, weights=weights,
        n_exp_preds=n_exp_preds,
        iou_thresh=iou_thresh,
        score_thresh=score_thresh,
    )
    # nms without label groups: overlapping boxes suppress each other
    # even when their classes differ
    keep = nms(clustered_boxes, clustered_scores, iou_thresh)
    return clustered_boxes[keep], clustered_scores[keep], clustered_labels[keep]
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple
from torch import Tensor
import torch
from nndet.core.boxes import batched_nms
def batched_nms_model(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Class-aware nms adapted to the model-ensembler call signature;
    also forwards the weights of the surviving boxes.

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels (used as nms groups)
        weights: weight per box
        iou_thresh: IoU threshold for nms
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: kept boxes
        Tensor: kept scores
        Tensor: kept labels
        Tensor: kept weights
    """
    keep = batched_nms(boxes=boxes, scores=scores, idxs=labels, iou_threshold=iou_thresh)
    return boxes[keep], scores[keep], labels[keep], weights[keep]
def batched_weighted_nms_model(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Class-aware nms where suppression ranking uses score * weight.
    The returned scores are the raw (unweighted) scores and the returned
    weights are reset to one, since the weight was already consumed here.

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels (used as nms groups)
        weights: weight per box, folded into the ranking score
        iou_thresh: IoU threshold for nms
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: kept boxes
        Tensor: kept (raw) scores
        Tensor: kept labels
        Tensor: weights reset to one
    """
    weighted_scores = scores * weights
    keep = batched_nms(boxes=boxes, scores=weighted_scores, idxs=labels, iou_threshold=iou_thresh)
    reset_weights = torch.ones_like(weights)
    return boxes[keep], scores[keep], labels[keep], reset_weights[keep]
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple
import torch
from torch import Tensor
from nndet.core.boxes import batched_nms, nms
from nndet.inference.detection import batched_wbc
def batched_nms_model(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Class-aware nms; also keeps the weights of surviving boxes."""
    keep = batched_nms(boxes=boxes, scores=scores, idxs=labels, iou_threshold=iou_thresh)
    return boxes[keep], scores[keep], labels[keep], weights[keep]
def batched_nms_ensemble(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Class-aware nms for the ensembler; weights are ignored."""
    keep = batched_nms(boxes=boxes, scores=scores, idxs=labels, iou_threshold=iou_thresh)
    return boxes[keep], scores[keep], labels[keep]
def batched_wbc_ensemble(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        n_exp_preds: Tensor,
        score_thresh: float,
        *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """Weighted box clustering adapted to the ensembler call signature."""
    return batched_wbc(
        boxes, scores, labels, weights=weights,
        n_exp_preds=n_exp_preds,
        iou_thresh=iou_thresh,
        score_thresh=score_thresh,
    )
def wbc_nms_no_label_ensemble(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        n_exp_preds: Tensor,
        score_thresh: float,
        *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Weighted box clustering followed by class-agnostic nms, yielding a
    single prediction per position regardless of the class.
    """
    clustered_boxes, clustered_scores, clustered_labels = batched_wbc(
        boxes, scores, labels, weights=weights,
        n_exp_preds=n_exp_preds,
        iou_thresh=iou_thresh,
        score_thresh=score_thresh,
    )
    keep = nms(clustered_boxes, clustered_scores, iou_thresh)
    return clustered_boxes[keep], clustered_scores[keep], clustered_labels[keep]
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch import Tensor
from typing import Tuple
from torch._C import device
from nndet.core.boxes import box_iou, box_area
__all__ = ["batched_wbc", "wbc"]
def batched_wbc(
        boxes: Tensor,
        scores: Tensor,
        labels: Tensor,
        weights: Tensor,
        iou_thresh: float,
        n_exp_preds: Tensor,
        score_thresh: float,
        use_area: bool = False,
        missing_weight: float = 1.,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Computed weighted box clustering per class

    Args:
        boxes: predicted boxes (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
        scores: predicted scores [N]
        labels: predicted labels [N]
        weights: weight for each box [N] (gaussian weighting of boxes near
            corners need to be included in this weight)
        iou_thresh: iou threshold used for clustering boxes
        n_exp_preds: number of expected predictions per box (computed as the
            mean number predictions inside the bounding box)
        score_thresh: minimum score of predictions after clustering
        use_area: assigns higher weights to larger boxes based on
            empirical observations indicating an increase in image
            evidence from larger areas.
        missing_weight: weight for score dampening when predictions are missing

    Returns:
        Tensor: clustered boxes
        Tensor: clustered scores
        Tensor: labels (same dtype as scores, produced via empty_like(scores))
    """
    clustered_boxes = []
    clustered_scores = []
    clustered_labels = []

    # run wbc independently per class so boxes of different classes
    # never cluster with each other
    for label in labels.unique():
        _labels_mask = labels == label
        _boxes = boxes[_labels_mask]
        _scores = scores[_labels_mask]
        _weights = weights[_labels_mask]
        _n_exp_preds = n_exp_preds[_labels_mask]

        b, s = wbc(_boxes, _scores,
                   weights=_weights, n_exp_preds=_n_exp_preds,
                   iou_thresh=iou_thresh, score_thresh=score_thresh,
                   use_area=use_area,
                   missing_weight=missing_weight,
                   )
        clustered_boxes.append(b)
        clustered_scores.append(s)
        clustered_labels.append(torch.empty_like(s).fill_(label))

    if clustered_boxes:
        return (torch.cat(clustered_boxes, dim=0),
                torch.cat(clustered_scores, dim=0),
                torch.cat(clustered_labels, dim=0))
    else:
        # fix: create empty results with the device/dtype of the inputs
        # (previously CPU/default dtype, breaking downstream ops on GPU);
        # labels follow scores to match the dtype of the non-empty branch
        return (torch.tensor([]).view(-1, boxes.shape[1]).to(boxes),
                torch.tensor([]).view(-1).to(scores),
                torch.tensor([]).view(-1).to(scores))
def wbc(
        boxes: Tensor,
        scores: Tensor,
        weights: Tensor,
        n_exp_preds: Tensor,
        iou_thresh: float,
        score_thresh: float,
        use_area: bool = True,
        missing_weight: float = 1.,
) -> Tuple[Tensor, Tensor]:
    """
    Weighted box clustering: greedily groups boxes around the highest
    scoring unprocessed box and consolidates each cluster into one
    box/score pair.

    Args:
        boxes: tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores: score for each box [N]
        weights: additional weights for boxes [N]
        n_exp_preds: expected number of predictions per box
        iou_thresh: iou threshold for determining clusters of boxes which are
            combined
        score_thresh: minimum scores of boxes after consolidation
        use_area: assigns higher weights to larger boxes based on
            empirical observations indicating an increase in image
            evidence from larger areas.
        missing_weight: weight for score dampening when predictions are missing

    Returns:
        Tensor: consolidated boxes
        Tensor: consolidated scores
    """
    # full pairwise iou matrix; rows of the cluster seeds are reused below
    ious = box_iou(boxes, boxes)
    if use_area:
        # scale weights with the box area (larger boxes -> more evidence)
        areas = box_area(boxes)
        weights = weights * areas

    # process boxes in order of descending score; idx_pool holds the
    # indices that have not been assigned to a cluster yet
    _, idx_pool = torch.sort(scores, descending=True)
    new_boxes, new_scores = [], []
    while idx_pool.nelement() > 0:
        # build cluster
        # seed is the highest scoring remaining box; all remaining boxes
        # with iou above the threshold join its cluster
        highest_scoring_id = idx_pool[0]
        matches = torch.where(ious[highest_scoring_id][idx_pool] > iou_thresh)[0].flatten()
        box_idx = idx_pool[matches]

        # compute new scores
        # expected predictions for this cluster = mean over its members
        n_expected = n_exp_preds[box_idx].float().mean()
        new_box, new_score = compute_cluster_consolidation(
            boxes[box_idx], scores[box_idx],
            weights=weights[box_idx],
            ious=ious[highest_scoring_id][box_idx],
            n_expected=n_expected,
            n_found=len(box_idx),
            missing_weight=missing_weight,
        )
        # low scoring clusters are dropped entirely
        if new_score > score_thresh:
            new_boxes.append(new_box)
            new_scores.append(new_score)

        # get all elements that were not matched and discard all others.
        non_matches = torch.where(ious[highest_scoring_id][idx_pool] <= iou_thresh)[0].flatten()
        idx_pool = idx_pool[non_matches]

    if new_boxes:
        return torch.stack(new_boxes, dim=0), torch.cat(new_scores, dim=0)
    else:
        # empty result keeps device/dtype of the inputs
        return torch.tensor([]).view(-1, boxes.shape[1]).to(boxes), torch.tensor([]).view(-1).to(scores)
def compute_cluster_consolidation(
        boxes: Tensor,
        scores: Tensor,
        weights: Tensor,
        ious: Tensor,
        n_expected: Tensor,
        n_found: int,
        missing_weight: float,
) -> Tuple[Tensor, Tensor]:
    """
    Consolidate the predictions of one cluster into a single box and score.
    Each member contributes proportionally to its iou with the cluster seed
    times its external weight; missing predictions dampen the final score.

    Args:
        boxes: boxes of a single cluster (x1, y1, x2, y2, (z1, z2) [N, dims * 2]
        scores: scores of a single cluster [N]
        weights: weights for boxes of a single cluster [N]
        ious: ious with regard to highest scoring box in a single cluster [N]
        n_expected: expected number of predictions
        n_found: number of predictions
        missing_weight: weight for score dampening when predictions are missing

    Returns:
        Tensor: consolidated box (x1, y1, x2, y2, (z1, z2)) [dims * 2]
        Tensor: consolidated score [1]
    """
    member_weights = ious * weights
    weighted_scores = member_weights * scores

    zero = torch.tensor([0.], device=n_expected.device)
    n_missing = torch.max(zero, (n_expected - n_found).float())

    # missing predictions enlarge the denominator and thus lower the score
    denominator = member_weights.sum() + n_missing * member_weights.mean() * missing_weight
    fused_score = weighted_scores.sum() / denominator

    # coordinates are averaged, weighted by the members' weighted scores
    fused_box = (boxes * weighted_scores.reshape(-1, 1)).sum(dim=0) / weighted_scores.sum()
    return fused_box, fused_score
def compute_cluster_consolidation2(
        boxes: Tensor,
        scores: Tensor,
        weights: Tensor,
        ious: Tensor,
        n_expected: Tensor,
        n_found: int,
        missing_weight: float,
) -> Tuple[Tensor, Tensor]:
    """
    Alternative cluster consolidation: keep only the top ``n_expected``
    members (ranked by iou * weight * score), average their raw scores and
    apply a penalty proportional to the fraction of missing predictions.

    Args:
        boxes: boxes of a single cluster (x1, y1, x2, y2, (z1, z2) [N, dims * 2]
        scores: scores of a single cluster [N]
        weights: weights for boxes of a single cluster [N]
        ious: ious with regard to highest scoring box in a single cluster [N]
        n_expected: expected number of predictions
        n_found: number of predictions
        missing_weight: weight for score dampening when predictions are missing

    Returns:
        Tensor: consolidated box (x1, y1, x2, y2, (z1, z2)) [dims * 2]
        Tensor: consolidated score [1]
    """
    # rank members by iou- and weight-adjusted score
    ranking = ious * weights * scores
    k = min(len(scores), int(n_expected))
    top_scores, top_idx = ranking.topk(k)
    top_boxes = boxes[top_idx]
    raw_scores = scores[top_idx]

    zero = torch.tensor([0.], device=n_expected.device)
    n_missing = torch.max(zero, (n_expected - n_found).float())

    # penalty term proportional to the fraction of missing predictions
    fused_score = raw_scores.mean() * (1 - missing_weight * n_missing / n_expected)
    fused_box = (top_boxes * top_scores.reshape(-1, 1)).sum(dim=0) / top_scores.sum()
    return fused_box, fused_score
from nndet.inference.ensembler.base import BaseEnsembler, BaseEnsemblerType, OverlapMap
from nndet.inference.ensembler.detection import BoxEnsembler
from nndet.inference.ensembler.segmentation import SegmentationEnsembler
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from os import PathLike
from pathlib import Path
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, TypeVar
import torch
from nndet.io.load import save_pickle
from nndet.utils.tensor import to_numpy
from nndet.utils.info import maybe_verbose_iterable
class BaseEnsembler(ABC):
    # identifier used in checkpoint filenames: {case_id}_{ID}.pt
    ID = "abstract"

    def __init__(self,
                 properties: Dict[str, Any],
                 parameters: Dict[str, Any],
                 device: Optional[Union[torch.device, str]] = None,
                 **kwargs):
        """
        Base class to containerize and ensemble the predictions of a single case.
        Call :method:`process_batch` to add batched predictions of a case
        to the ensembler and :method:`add_model` to signal the next model
        if multiple models are used.

        Args:
            properties: properties of the patient/case (e.g. tranpose axes)
            parameters: parameters for ensembling; extended/overwritten by
                **kwargs
            device: device to use for internal computations; a string is
                converted via torch.device, None falls back to CPU
            **kwargs: parameters for ensembling

        Raises:
            ValueError: if device is neither None, str nor torch.device

        Notes:
            Call :method:`add_model` before adding predictions.
        """
        # name of the model currently receiving predictions (set by add_model)
        self.model_current = None
        # per-model accumulated predictions: name -> defaultdict(list)
        self.model_results = {}
        # per-model ensembling weight: name -> float
        self.model_weights = {}
        self.properties = properties
        # final ensembled result of the case; filled by subclasses
        self.case_result: Optional[Dict] = None

        # NOTE(review): if parameters is None (allowed by from_case),
        # the update below raises AttributeError — confirm callers always
        # pass a dict.
        self.parameters = parameters
        self.parameters.update(kwargs)

        if device is None:
            self.device = torch.device("cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        elif isinstance(device, torch.device):
            self.device = device
        else:
            raise ValueError(f"Wrong type {type(device)} for device argument.")

    @classmethod
    def from_case(cls,
                  case: Dict,
                  properties: Optional[Dict] = None,
                  parameters: Optional[Dict] = None,
                  **kwargs,
                  ):
        """
        Primary way to instantiate this class. Automatically extracts all
        properties and uses a default set of parameters for ensembling.

        Args:
            case: case which is predicted (unused in the base implementation;
                subclasses may extract properties from it)
            properties: Additional properties. Defaults to None.
            parameters: Additional parameters. Defaults to None.
        """
        return cls(properties=properties, parameters=parameters, **kwargs)

    def add_model(self,
                  name: Optional[str] = None,
                  model_weight: Optional[float] = None,
                  ) -> str:
        """
        This functions signales the ensembler to add a new model for internal
        processing

        Args:
            name: Name of the model. If None, uses counts the models.
                NOTE(review): the generated fallback name is an int
                (len + 1) although the return annotation says str —
                confirm whether callers rely on str keys.
            model_weight: Optional weight for this model. Defaults to None
                (treated as 1.0).

        Raises:
            ValueError: if a model with the same name was already added

        Returns:
            str: the name under which the model was registered
        """
        if name is None:
            name = len(self.model_weights) + 1
        if name in self.model_results:
            raise ValueError(f"Invalid model name, model {name} is already present")
        if model_weight is None:
            model_weight = 1.0

        self.model_weights[name] = model_weight
        self.model_results[name] = defaultdict(list)
        # subsequent process_batch calls accumulate into this model
        self.model_current = name
        return name

    @abstractmethod
    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch

        Args:
            result: predictions to save and ensemble
            batch: input batch used for predictions (for additional meta data)

        Raises:
            NotImplementedError: Overwrite this function in subclasses for the
                specific use case.

        Warnings:
            Make sure to move cached values to the CPU after they have been
            processed.
        """
        raise NotImplementedError

    @abstractmethod
    @torch.no_grad()
    def get_case_result(self, restore: bool = False) -> Dict[str, torch.Tensor]:
        """
        Retrieve the results of a single case

        Args:
            restore: restores predictions in original image space

        Raises:
            NotImplementedError: Overwrite this function in subclasses for the
                specific use case.

        Returns:
            Dict[str, torch.Tensor]: the result of a single case
        """
        raise NotImplementedError

    def update_parameters(self, **parameters: Dict):
        """
        Update internal parameters used for ensembling the results

        Args:
            parameters: parameters to update
        """
        self.parameters.update(parameters)

    @classmethod
    @abstractmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any], Dict[str, Sequence[Any]]]:
        """
        Return a set of parameters which can be used to sweep ensembling
        parameters in a postprocessing step

        Returns:
            Dict[str, Any]: default state to start with
            Dict[str, Sequence[Any]]]: Defines the values to search for each
                parameter
        """
        raise NotImplementedError

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save case result as pickle file. Identifier of ensembler will
        be added to the name

        Args:
            target_dir: folder to save result to
            name: name of case
            **kwargs: data to save
        """
        # the full internal state is saved alongside any extra data so the
        # ensembler can be reconstructed via from_checkpoint/load_state
        kwargs["properties"] = self.properties
        kwargs["parameters"] = self.parameters
        kwargs["model_current"] = self.model_current
        kwargs["model_results"] = self.model_results
        kwargs["model_weights"] = self.model_weights
        kwargs["case_result"] = self.case_result

        with open(Path(target_dir) / f"{name}_{self.ID}.pt", "wb") as f:
            torch.save(kwargs, f)

    def load_state(self, base_dir: PathLike, case_id: str) -> Dict:
        """
        Load a saved state into this instance and return the raw checkpoint.

        Args:
            base_dir: directory containing `{case_id}_{ID}.pt`
            case_id: case identifier

        Returns:
            Dict: the loaded checkpoint data
        """
        ckp = torch.load(str(Path(base_dir) / f"{case_id}_{self.ID}.pt"))
        self._load(ckp)
        return ckp

    def _load(self, state: Dict):
        # sets every checkpoint entry as an attribute on this instance
        for key, item in state.items():
            setattr(self, key, item)

    @classmethod
    def from_checkpoint(cls, base_dir: PathLike, case_id: str):
        """
        Instantiate an ensembler from a saved state file.
        """
        ckp = torch.load(str(Path(base_dir) / f"{case_id}_{cls.ID}.pt"))
        t = cls(
            properties=ckp["properties"],
            parameters=ckp["parameters"],
        )
        t._load(ckp)
        return t

    @classmethod
    def get_case_ids(cls, base_dir: PathLike):
        """
        List the case ids of all saved states of this ensembler type
        inside base_dir (derived from `{case_id}_{ID}.pt` filenames).
        """
        return [c.stem.rsplit(f"_{cls.ID}", 1)[0]
                for c in Path(base_dir).glob(f"*_{cls.ID}.pt")]
class OverlapMap:
    """
    Tracks, per spatial position, how many crops covered it.
    """

    def __init__(self, data_shape: Sequence[int]):
        """
        Handler for overlap map

        Args:
            data_shape: spatial dimensions of data (
                no batch dim and no channel dim!)
        """
        self.overlap_map: torch.Tensor = \
            torch.zeros(*data_shape, requires_grad=False, dtype=torch.float)

    def add_overlap(self, crop: Sequence[slice]):
        """
        Increase values of :param:`self.overlap_map` inside of crop

        Args:
            crop: defines crop. Negative values are assumed to be outside
                of the data and thus discarded
        """
        # discard leading indexes which could be due to batches and channels
        if len(crop) > self.overlap_map.ndim:
            crop = crop[-self.overlap_map.ndim:]

        # clip crop to data shape
        slicer = []
        for data_shape, crop_dim in zip(tuple(self.overlap_map.shape), crop):
            start = max(0, crop_dim.start)
            stop = min(data_shape, crop_dim.stop)
            slicer.append(slice(start, stop, crop_dim.step))
        # fix: index with a tuple — indexing a tensor with a *list* of slices
        # is treated as advanced indexing and fails in newer torch versions
        self.overlap_map[tuple(slicer)] += 1

    def mean_num_overlap_of_box(self, box: Sequence[int]) -> float:
        """
        Extract mean number of overlaps from a bounding box area

        Args:
            box: defines bounding box (x1, y1, x2, y2, (z1, z2))

        Returns:
            float: mean number of overlaps
        """
        slicer = [slice(int(box[0]), int(box[2])), slice(int(box[1]), int(box[3]))]
        if len(box) == 6:
            slicer.append(slice(int(box[4]), int(box[5])))
        return torch.mean(self.overlap_map[tuple(slicer)].float()).item()

    def mean_num_overlap_of_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Extract mean number of overlaps from a bounding box area

        Args:
            boxes: defines multiple bounding boxes (x1, y1, x2, y2, (z1, z2))
                [N, dim * 2]

        Returns:
            Tensor: mean number of overlaps per box [N]
        """
        return torch.tensor(
            [self.mean_num_overlap_of_box(box) for box in boxes]).to(
                dtype=torch.float, device=boxes.device)

    def avg(self) -> torch.Tensor:
        """
        Compute the median over all overlaps.
        NOTE(review): despite the name "avg" this computes the median,
        as the original implementation did — kept for compatibility.
        """
        return self.overlap_map.float().median()

    def restore_mean(self, val):
        """
        Generate a new overlap map uniformly filled with the specified value

        Args:
            val: fill value
        """
        # fix: the original overwrote the tensor attribute with a plain
        # float, breaking every subsequent tensor operation on the map
        self.overlap_map = torch.full_like(self.overlap_map, float(val))
def extract_results(source_dir: PathLike,
                    target_dir: PathLike,
                    ensembler_cls: Callable,
                    restore: bool,
                    **params,
                    ) -> None:
    """
    Compute case result from ensembler and save it

    Args:
        source_dir: directory which contains the saved predictions/state from
            the ensembler class
        target_dir: directory to save results
        ensembler_cls: ensembler class for prediction
        restore: if true, the results are converted into the original image
            space
        **params: overwrite ensembler parameters before computing the result
    """
    out_dir = Path(target_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for case_id in maybe_verbose_iterable(ensembler_cls.get_case_ids(source_dir)):
        ensembler = ensembler_cls.from_checkpoint(base_dir=source_dir, case_id=case_id)
        ensembler.update_parameters(**params)
        case_result = to_numpy(ensembler.get_case_result(restore=restore))
        save_pickle(case_result, out_dir / f"{case_id}_{ensembler_cls.ID}.pkl")
# Type variable bound to BaseEnsembler; used to annotate functions that
# accept/return any concrete ensembler subclass.
BaseEnsemblerType = TypeVar('BaseEnsemblerType', bound=BaseEnsembler)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Hashable, Union
import torch
import numpy as np
from scipy.stats import norm
from torch import Tensor
from loguru import logger
from nndet.inference.detection.model import batched_weighted_nms_model
from nndet.inference.detection import batched_nms_model, batched_nms_ensemble, \
batched_wbc_ensemble, wbc_nms_no_label_ensemble
from nndet.inference.ensembler.base import BaseEnsembler, OverlapMap
from nndet.inference.restore import restore_detection
from nndet.core.boxes import box_center, clip_boxes_to_image, remove_small_boxes
from nndet.utils.tensor import cat, to_device, to_dtype
class BoxEnsembler(BaseEnsembler):
    """
    Ensemble bounding box detections of a single case produced by
    test-time augmentation (tiling) and/or multiple models.
    """
    ID = "boxes"

    def __init__(self,
                 properties: Dict[str, Any],
                 parameters: Dict[str, Any],
                 box_key: str = 'pred_boxes',
                 score_key: str = 'pred_scores',
                 label_key: str = 'pred_labels',
                 data_key: str = 'data',
                 device: Optional[Union[torch.device, str]] = None,
                 **kwargs):
        """
        Ensemble bounding box detections from tta and multiple models

        Args:
            properties: properties of the patient/case (e.g. transpose axes)
            parameters: parameters for ensembling
            box_key: key where boxes are located inside prediction dict
            score_key: key where scores are located inside prediction dict
            label_key: key where labels are located inside prediction dict
            data_key: key where data is located inside batch dict
            device: device to use for internal computations
            kwargs: passed to super class
        """
        super().__init__(
            properties=properties,
            parameters=parameters,
            device=device,
            **kwargs,
        )
        # parameters to access information from predictions and batches
        self.data_key = data_key
        self.score_key = score_key
        self.label_key = label_key
        self.box_key = box_key

        # counts, per voxel, how many tiles covered it; used to estimate
        # the expected number of predictions per object during ensembling
        self.overlap_map = OverlapMap(tuple(self.properties["shape"]))

    @classmethod
    def from_case(cls,
                  case: Dict,
                  properties: Dict,
                  parameters: Optional[Dict] = None,
                  box_key: str = 'pred_boxes',
                  score_key: str = 'pred_scores',
                  label_key: str = 'pred_labels',
                  data_key: str = 'data',
                  device: Optional[Union[torch.device, str]] = None,
                  **kwargs,
                  ):
        """
        Primary way to instantiate this class. Automatically extracts all
        properties and uses a default set of parameters for ensembling.

        Args:
            case: case which is predicted.
            properties: Additional properties.
                Required keys:
                    `transpose_backward`
                    `spacing_after_resampling`
                    `crop_bbox`
            parameters: Additional parameters. Defaults to None.
            box_key: key where boxes are located inside prediction dict
            score_key: key where scores are located inside prediction dict
            label_key: key where labels are located inside prediction dict
            data_key: key where data is located inside batch dict
            device: device to use for internal computations
        """
        _parameters = cls.get_default_parameters()
        # Bug fix: `parameters` defaults to None and dict.update(None)
        # raises TypeError; only merge explicitly provided overrides.
        if parameters is not None:
            _parameters.update(parameters)
        _properties = {
            "shape": case[data_key].shape[1:],  # remove channel dim
            "transpose_backward": properties["transpose_backward"],
            "original_spacing": properties["original_spacing"],
            "spacing_after_resampling": properties["spacing_after_resampling"],
            "crop_bbox": properties["crop_bbox"],
            "original_size_of_raw_data": properties["original_size_of_raw_data"],
            "itk_origin": properties["itk_origin"],
            "itk_spacing": properties["itk_spacing"],
            "itk_direction": properties["itk_direction"],
        }
        return cls(
            properties=_properties,
            parameters=_parameters,
            box_key=box_key,
            score_key=score_key,
            label_key=label_key,
            data_key=data_key,
            device=device,
            **kwargs,
        )

    @classmethod
    def get_default_parameters(cls):
        """
        Generate default parameters for instantiation

        Returns:
            Dict:
                `model_iou`: IoU used by the per-model NMS function
                `model_nms_fn`: function to use for model NMS
                `model_score_thresh`: minimum probability for a single model
                `model_topk`: number of predictions with the highest
                    probability to keep per model
                `model_detections_per_image`: max detections kept per image
                `ensemble_iou`: IoU for ensembling the predictions of
                    multiple models
                `ensemble_nms_fn`: ensemble predictions from multiple models
                `ensemble_topk`: number of predictions with the highest
                    probability to keep in the ensemble
                `remove_small_boxes`: minimum size of the box
                `ensemble_score_thresh`: minimum probability in the ensemble
        """
        return {
            # single model
            "model_iou": 0.1,
            "model_nms_fn": batched_nms_model,
            "model_score_thresh": 0.0,
            "model_topk": 1000,
            "model_detections_per_image": 100,
            # ensemble multiple models
            "ensemble_iou": 0.5,
            "ensemble_nms_fn": batched_wbc_ensemble,
            "ensemble_topk": 1000,
            "remove_small_boxes": 1e-2,
            "ensemble_score_thresh": 0.0,
        }

    def postprocess_image(self,
                          boxes: torch.Tensor,
                          probs: torch.Tensor,
                          labels: torch.Tensor,
                          weights: torch.Tensor,
                          shape: Optional[Tuple[int]] = None
                          ) -> Tuple[torch.Tensor, torch.Tensor,
                                     torch.Tensor, torch.Tensor]:
        """
        Postprocessing of a single image:
        select topk predictions -> score threshold -> clipping -> \
            remove small boxes -> nms -> cap detections per image

        Args:
            boxes: predicted boxes [N, dim * 2]
            probs: predicted probabilities for boxes [N]
            labels: predicted labels for boxes [N]
            weights: weight for each box [N]
            shape: image shape boxes are clipped to

        Returns:
            torch.Tensor: postprocessed boxes
            torch.Tensor: postprocessed probs
            torch.Tensor: postprocessed labels
            torch.Tensor: postprocessed weights
        """
        p_sorted, idx_sorted = probs.sort(descending=True)
        idx_sorted = idx_sorted[:self.parameters["model_topk"]]
        p_sorted = p_sorted[:self.parameters["model_topk"]]

        keep_idxs = p_sorted > self.parameters["model_score_thresh"]
        idx_sorted = idx_sorted[keep_idxs]

        b, p, l, w = boxes[idx_sorted], probs[idx_sorted], labels[idx_sorted], weights[idx_sorted]

        b = clip_boxes_to_image(b, shape)
        # After clipping we could have boxes with volume 0 which we definitely
        # need to remove because of the IoU computation
        keep = remove_small_boxes(
            b, min_size=self.parameters["remove_small_boxes"])
        b, p, l, w = b[keep], p[keep], l[keep], w[keep]

        _boxes, _probs, _labels, _weights = self.parameters["model_nms_fn"](
            boxes=b, scores=p, labels=l, weights=w,
            iou_thresh=self.parameters["model_iou"],
        )

        # predictions are sorted, so slicing keeps the highest scoring ones
        _boxes = _boxes[:self.parameters.get("model_detections_per_image", 1000)]
        _probs = _probs[:self.parameters.get("model_detections_per_image", 1000)]
        _labels = _labels[:self.parameters.get("model_detections_per_image", 1000)]
        _weights = _weights[:self.parameters.get("model_detections_per_image", 1000)]
        return _boxes, _probs, _labels, _weights

    @staticmethod
    def _apply_offsets_to_boxes(boxes: List[Tensor],
                                tile_offset: Sequence[Sequence[int]],
                                ) -> List[Tensor]:
        """
        Apply offset to bounding boxes to position them correctly inside
        the whole case

        Args:
            boxes: predicted boxes [N, dims * 2]
                (x1, y1, x2, y2, (z1, z2))
            tile_offset: defines offset for each tile

        Returns:
            List[Tensor]: bounding boxes with respect to origin of whole case
        """
        offset_boxes = []
        for img_boxes, offset in zip(boxes, tile_offset):
            if img_boxes.nelement() == 0:
                # nothing predicted in this tile; keep the empty tensor
                offset_boxes.append(img_boxes)
                continue
            offset = Tensor(offset).to(img_boxes)
            _boxes = img_boxes.clone()
            _boxes[:, 0] += offset[0]
            _boxes[:, 1] += offset[1]
            _boxes[:, 2] += offset[0]
            _boxes[:, 3] += offset[1]
            if img_boxes.shape[1] == 6:
                # 3D boxes carry z coordinates in the last two columns
                _boxes[:, 4] += offset[2]
                _boxes[:, 5] += offset[2]
            offset_boxes.append(_boxes)
        return offset_boxes

    def restore_prediction(self, boxes: Tensor):
        """
        Restore predictions in the original image space

        Args:
            boxes: predicted boxes [N, dims * 2] (x1, y1, x2, y2, (z1, z2))

        Returns:
            Tensor: boxes in original image space [N, dims * 2]
                (x1, y1, x2, y2, (z1, z2))
        """
        _old_dtype = boxes.dtype
        boxes_np = restore_detection(
            boxes.detach().cpu().numpy(),
            transpose_backward=self.properties["transpose_backward"],
            original_spacing=self.properties["original_spacing"],
            spacing_after_resampling=self.properties["spacing_after_resampling"],
            crop_bbox=self.properties["crop_bbox"],
        )
        boxes = torch.from_numpy(boxes_np).to(dtype=_old_dtype)
        return boxes

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save case result as pickle file. Identifier of ensembler will
        be added to the name

        Args:
            target_dir: folder to save result to
            name: name of case

        Notes:
            The device is not saved inside the checkpoint and everything
            will be loaded on the CPU.
        """
        super().save_state(
            target_dir=target_dir,
            name=name,
            score_key=self.score_key,
            label_key=self.label_key,
            box_key=self.box_key,
            data_key=self.data_key,
            overlap_map=self.overlap_map,
            **kwargs,
        )

    @classmethod
    def from_checkpoint(cls, base_dir: PathLike, case_id: str, **kwargs):
        """
        Restore an ensembler instance from a saved checkpoint.

        Args:
            base_dir: directory containing the checkpoint
            case_id: case identifier used in the checkpoint file name
            kwargs: passed to the constructor
        """
        # map_location enforces the CPU loading promised in `save_state`
        # even if tensors were saved from a CUDA device
        ckp = torch.load(str(Path(base_dir) / f"{case_id}_{cls.ID}.pt"),
                         map_location="cpu")
        t = cls(
            properties=ckp["properties"],
            parameters=ckp["parameters"],
            box_key=ckp["box_key"],
            score_key=ckp["score_key"],
            label_key=ckp["label_key"],
            data_key=ckp["data_key"],
            **kwargs
        )
        t._load(ckp)
        return t

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any],
                                       Dict[str, Sequence[Any]]]:
        """
        Return default parameters and the search space for parameter sweeps.
        """
        # iou_threshs = np.linspace(0.0, 0.8, 9)
        iou_threshs = np.linspace(0.0, 0.5, 6)
        iou_threshs[0] = 1e-5
        small_boxes_thresh = np.linspace(2., 7., 6)

        param_sweep = {
            # ensemble multiple models
            "ensemble_iou": iou_threshs,
            "model_score_thresh": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
            # "remove_small_boxes": small_boxes_thresh,
        }
        return cls.get_default_parameters(), param_sweep

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of bounding box predictions
        (the boxes are clipped to the case size in the ensembling step)

        Args:
            result: prediction from detector. Need to provide boxes, scores
                and class labels
                `self.box_key`: List[Tensor]: predicted boxes (relative
                    to patch coordinates)
                `self.score_key` List[Tensor]: score for each tensor
                `self.label_key`: List[Tensor] label prediction for each box
            batch: input batch for detector
                `tile_origin`: origin of crop with respect to actual data (
                    in case of padding)
                `crop`: Sequence[slice] original crop from data

        Warnings:
            Make sure to move cached values to the CPU after they have been
            processed.
        """
        tile_origins = [to for to in zip(*batch["tile_origin"])]
        tile_size = batch[self.data_key].shape[2:]

        boxes = []
        scores = []
        labels = []
        for b, s, l in zip(result[self.box_key], result[self.score_key], result[self.label_key]):
            # per-tile postprocessing (topk, thresholding, clipping, NMS)
            _boxes, _scores, _labels, _ = self.postprocess_image(
                boxes=b.float(),
                probs=s.float(),
                labels=l.float(),
                weights=torch.ones_like(s).float(),
                shape=tuple(tile_size),
            )
            boxes.append(_boxes.cpu())
            scores.append(_scores.cpu())
            labels.append(_labels.cpu())

        centers = [box_center(img_boxes) if img_boxes.numel() > 0 else Tensor([]).to(img_boxes)
                   for img_boxes in boxes]
        # down-weight boxes near tile borders, scale by the model weight
        weights = [self._get_box_in_tile_weight(c, tile_size) for c in centers]
        weights = [w * self.model_weights[self.model_current] for w in weights]
        boxes = self._apply_offsets_to_boxes(boxes, tile_origins)

        self.model_results[self.model_current]["boxes"].extend(boxes)
        self.model_results[self.model_current]["scores"].extend(scores)
        self.model_results[self.model_current]["labels"].extend(labels)
        self.model_results[self.model_current]["weights"].extend(weights)

        crops_reshaped = list(zip(*batch["crop"]))
        self.model_results[self.model_current]["crops"].extend(crops_reshaped)
        for crop in crops_reshaped:
            self.overlap_map.add_overlap(crop)

    @staticmethod
    def _get_box_in_tile_weight(box_centers: Tensor,
                                tile_size: Sequence[int],
                                ) -> Tensor:
        """
        Assign boxes at the corners of tiles a lower weight (weight
        is drawn from a scaled normal distribution centered at the
        tile center)

        Args:
            box_centers: center predicted box [N, dims]
            tile_size: size the of patch/tile

        Returns:
            Tensor: weight for each bounding box [N]
        """
        if box_centers.numel() > 0:
            all_weights = []
            centers_np = box_centers.detach().cpu().numpy()
            for center_np in centers_np:
                # normalized Gaussian per axis: pdf * sqrt(2*pi) * sigma == 1
                # at the tile center, decaying towards the borders
                weight = np.mean([
                    norm.pdf(bc, loc=ps, scale=ps * 0.8) * np.sqrt(2 * np.pi) * ps * 0.8
                    for bc, ps in zip(center_np, np.array(tile_size) / 2)])
                all_weights.append([weight])
            return torch.from_numpy(np.concatenate(all_weights)).to(box_centers)
        else:
            return Tensor([]).to(box_centers)

    @torch.no_grad()
    def get_case_result(self,
                        restore: bool = False,
                        names: Optional[Sequence[Hashable]] = None,
                        ) -> Dict[str, Tensor]:
        """
        Process all the batches and models and create the final prediction

        Args:
            restore: restore prediction in the original image space
            names: name of the models to use. By default all models are used.

        Returns:
            Dict: final result
                `pred_boxes`: predicted box locations
                    [N, dims * 2] (x1, y1, x2, y2, (z1, z2))
                `pred_scores`: predicted probability per box [N]
                `pred_labels`: predicted label per box [N]
                `restore`: indicate whether predictions were restored in
                    original image space
                `original_size_of_raw_data`: image shape before preprocessing
                `itk_origin`: itk origin of image before preprocessing
                `itk_spacing`: itk spacing of image before preprocessing
                `itk_direction`: itk direction of image before preprocessing
        """
        if names is None:
            names = list(self.model_results.keys())

        boxes, probs, labels, weights = [], [], [], []
        for name in names:
            _boxes, _probs, _labels, _weights = self.process_model(name)
            boxes.append(_boxes)
            probs.append(_probs)
            labels.append(_labels)
            weights.append(_weights)

        boxes, probs, labels = self.process_ensemble(
            boxes=boxes, probs=probs, labels=labels,
            weights=weights,
        )

        if restore:
            boxes = self.restore_prediction(boxes)

        return {
            "pred_boxes": boxes,
            "pred_scores": probs,
            "pred_labels": labels,
            "restore": restore,
            "original_size_of_raw_data": self.properties["original_size_of_raw_data"],
            "itk_origin": self.properties["itk_origin"],
            "itk_spacing": self.properties["itk_spacing"],
            "itk_direction": self.properties["itk_direction"],
        }

    def process_model(self, name: Hashable) ->\
            Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Process the output of a single model on the whole scan by
        concatenating the cached per-batch results.

        Args:
            name: name of model to process

        Returns:
            Tensor: boxes of model
            Tensor: probs of model
            Tensor: labels of model
            Tensor: weights of model
        """
        # concatenate batches
        boxes = cat(self.model_results[name]["boxes"], dim=0)
        probs = cat(self.model_results[name]["scores"], dim=0)
        labels = cat(self.model_results[name]["labels"], dim=0)
        weights = cat(self.model_results[name]["weights"], dim=0)
        return boxes, probs, labels, weights

    def process_ensemble(self, boxes: List[Tensor], probs: List[Tensor],
                         labels: List[Tensor], weights: List[Tensor],
                         ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Ensemble predictions from multiple models

        Args:
            boxes: predicted boxes List[[N, dims * 2]]
                (x1, y1, x2, y2, (z1, z2))
            probs: predicted probabilities List[[N]]
            labels: predicted label List[[N]]
            weights: additional weight List[[N]]

        Returns:
            Tensor: ensembled box predictions
            Tensor: ensembled probabilities
            Tensor: ensembled labels
        """
        boxes = cat(boxes, dim=0)
        probs = cat(probs, dim=0)
        labels = cat(labels, dim=0)
        weights = cat(weights, dim=0)

        _, idx = probs.sort(descending=True)
        idx = idx[:self.parameters["ensemble_topk"]]
        boxes = boxes[idx]
        probs = probs[idx]
        labels = labels[idx]
        weights = weights[idx]

        n_exp_preds = self.overlap_map.mean_num_overlap_of_boxes(boxes)
        # Bug fix: the ensemble NMS previously used "model_iou", which made
        # the "ensemble_iou" parameter (default and swept value) a no-op.
        boxes, probs, labels = self.parameters["ensemble_nms_fn"](
            boxes, probs, labels,
            weights=weights,
            iou_thresh=self.parameters["ensemble_iou"],
            n_exp_preds=n_exp_preds,
            score_thresh=self.parameters["ensemble_score_thresh"],
        )
        return boxes.cpu(), probs.cpu(), labels.cpu()
class BoxEnsemblerLW(BoxEnsembler):
    """
    Uses different computation for box weight, much faster than box ensembler.
    """

    @staticmethod
    def _get_box_in_tile_weight(box_centers: Tensor,
                                tile_size: Sequence[int],
                                ) -> Tensor:
        """
        Assign boxes near the corner a lower weight.
        The middle has a plateau with weight one, starting from patchsize / 2
        the weights decreases linearly until 0.5 is reached.

        Args:
            box_centers: center predicted box [N, dims]
            tile_size: size the of patch/tile

        Returns:
            Tensor: weight for each bounding box [N]
        """
        plateau_length = 0.5  # adjust width of plateau and min weight
        if box_centers.numel() == 0:
            return Tensor([]).to(box_centers)

        half_tile = torch.tensor(tile_size).to(box_centers) / 2.  # [dims]
        max_dist = half_tile.norm(p=2)  # [1]
        rel_dist = (box_centers - half_tile[None]).norm(p=2, dim=1) / max_dist  # [N]
        # weight = 1 on the plateau, linear falloff outside; identical to
        # 1 - max(rel_dist - plateau_length, 0)
        return torch.clamp(1. - (rel_dist - plateau_length), max=1.)
class BoxEnsemblerFastest(BoxEnsemblerLW):
    """
    Uses the fastest but not necessarily most precise box ensembling strategy

    Only save top `num_reduced_cache` boxes for ensembling
    Uses a linear box weight
    Uses the mean over the whole overlap map. Depending on overlap
    and patch stride this is not correct.
    """

    def __init__(self, *args, **kwargs):
        """
        See superclass for arguments.
        """
        super().__init__(*args, **kwargs)
        # True once the cached predictions were pruned to `num_reduced_cache`
        self.reduced_cache = False
        self.num_reduced_cache = 8000
        # scalar proxy for the overlap map, computed in `reduce_cache`
        self.overlap_map_mean = None

    @classmethod
    def get_default_parameters(cls):
        """
        Generate default parameters for instantiation

        Returns:
            Dict:
                `model_iou`: IoU used by the per-model NMS function
                `model_nms_fn`: function to use for model NMS
                `model_score_thresh`: minimum probability for a single model
                `model_topk`: number of predictions with the highest
                    probability to keep per model
                `model_detections_per_image`: max detections kept per image
                `ensemble_iou`: IoU for ensembling the predictions of
                    multiple models
                `ensemble_nms_fn`: ensemble predictions from multiple models
                `ensemble_topk`: number of predictions with the highest
                    probability to keep in the ensemble
                `remove_small_boxes`: minimum size of the box
                `ensemble_score_thresh`: minimum probability in the ensemble
        """
        return {
            # single model
            "model_iou": 0.1,
            "model_nms_fn": batched_nms_model,
            "model_score_thresh": 0.1,
            "model_topk": 1000,
            "model_detections_per_image": 1000,
            # ensemble multiple models
            "ensemble_iou": 0.5,
            "ensemble_nms_fn": batched_wbc_ensemble,
            "ensemble_topk": 1000,
            "remove_small_boxes": 1e-2,
            "ensemble_score_thresh": 0.0,
        }

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any],
                                       Dict[str, Sequence[Any]]]:
        """
        Return default parameters and the search space for parameter sweeps.
        """
        iou_threshs = np.linspace(0.0, 0.5, 6)
        iou_threshs[0] = 1e-5
        small_boxes_thresh = [1e-2] + np.linspace(2., 7., 6).tolist()

        param_sweep = {
            # single model
            "model_iou": iou_threshs,
            # ensemble multiple models
            "ensemble_iou": iou_threshs,
            "remove_small_boxes": small_boxes_thresh,
            "model_score_thresh": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        }
        return cls.get_default_parameters(), param_sweep

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of bounding box predictions
        (the boxes are clipped to the case size in the ensembling step)

        Args:
            result: prediction from detector. Need to provide boxes, scores
                and class labels
                `self.box_key`: List[Tensor]: predicted boxes (relative
                    to patch coordinates)
                `self.score_key` List[Tensor]: score for each tensor
                `self.label_key`: List[Tensor] label prediction for each box
            batch: input batch for detector
                `tile_origin`: origin of crop with respect to actual data (
                    in case of padding)
                `crop`: Sequence[slice] original crop from data
        """
        if self.reduced_cache:
            # reduce_cache already ran (e.g. via save_state); the overlap
            # map needs to be restored with its mean proxy before new
            # batches are added
            logger.warning("Ensembler was already reduced, need to rerun reduce_cache "
                           "later and restore overlap map with proxy mean.")
            self.overlap_map.restore_mean(self.overlap_map_mean)
            self.reduced_cache = False

        # cache raw predictions in half precision on the CPU to save memory
        boxes = [r.half().cpu() for r in result[self.box_key]]
        scores = [r.half().cpu() for r in result[self.score_key]]
        labels = [r.half().cpu() for r in result[self.label_key]]

        centers = [box_center(img_boxes) if img_boxes.numel() > 0 else Tensor([]).to(img_boxes)
                   for img_boxes in boxes]
        tile_origins = [to for to in zip(*batch["tile_origin"])]
        tile_size = batch[self.data_key].shape[2:]

        weights = [self._get_box_in_tile_weight(c, tile_size) for c in centers]
        weights = [w * self.model_weights[self.model_current] for w in weights]
        boxes = self._apply_offsets_to_boxes(boxes, tile_origins)

        self.model_results[self.model_current]["boxes"].extend(boxes)
        self.model_results[self.model_current]["scores"].extend(scores)
        self.model_results[self.model_current]["labels"].extend(labels)
        self.model_results[self.model_current]["weights"].extend(weights)

        crops_reshaped = list(zip(*batch["crop"]))
        self.model_results[self.model_current]["crops"].extend(crops_reshaped)
        for crop in crops_reshaped:
            self.overlap_map.add_overlap(crop)

    @staticmethod
    def _get_box_in_tile_weight(box_centers: Tensor,
                                tile_size: Sequence[int],
                                ) -> Tensor:
        """
        Assign boxes near the corner a lower weight.
        The middle has a plateau with weight one, starting from patchsize / 2
        the weights decreases linearly until 0.5 is reached.
        (half-precision variant matching the half-precision cache)

        Args:
            box_centers: center predicted box [N, dims]
            tile_size: size the of patch/tile

        Returns:
            Tensor: weight for each bounding box [N]
        """
        plateau_length = 0.5  # adjust width of plateau and min weight
        if box_centers.numel() > 0:
            tile_center = torch.tensor(tile_size).to(box_centers) / 2.  # [dims]
            max_dist = tile_center.norm(p=2)  # [1]
            boxes_dist = (box_centers - tile_center[None]).norm(p=2, dim=1)  # [N]
            # clamp_ is not implemented for half on CPU -> round trip via float
            weight = -(boxes_dist / max_dist - plateau_length).float().clamp_(min=0).half() + 1
            return weight
        else:
            return Tensor([]).to(box_centers).half()

    def process_model(self,
                      name: Hashable,
                      ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Process the output of a single model on the whole scan
        topk candidates -> nms

        Args:
            name: name of model to process

        Returns:
            Tensor: processed boxes
            Tensor: processed probs
            Tensor: processed labels
            Tensor: processed weights
        """
        boxes = to_device(self.model_results[name]["boxes"], device=self.device)
        probs = to_device(self.model_results[name]["scores"], device=self.device)
        labels = to_device(self.model_results[name]["labels"], device=self.device)
        weights = to_device(self.model_results[name]["weights"], device=self.device)

        model_boxes = []
        model_probs = []
        model_labels = []
        model_weights = []
        for b, p, l, w in zip(boxes, probs, labels, weights):
            if b.numel() > 0:
                # postprocess per cached batch entry (clipped to case shape)
                _b, _p, _l, _w = self.postprocess_image(
                    boxes=b.float(),
                    probs=p.float(),
                    labels=l.float(),
                    weights=w.float(),
                    shape=tuple(self.properties["shape"]),
                )
                model_boxes.append(_b)
                model_probs.append(_p)
                model_labels.append(_l)
                model_weights.append(_w)
        return cat(model_boxes), cat(model_probs), cat(model_labels), cat(model_weights)

    def process_ensemble(self,
                         boxes: List[Tensor],
                         probs: List[Tensor],
                         labels: List[Tensor],
                         weights: List[Tensor],
                         ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Ensemble predictions from multiple models

        Args:
            boxes: predicted boxes List[[N, dims * 2]]
                (x1, y1, x2, y2, (z1, z2))
            probs: predicted probabilities List[[N]]
            labels: predicted label List[[N]]
            weights: additional weight List[[N]]

        Returns:
            Tensor: ensembled box predictions
            Tensor: ensembled probabilities
            Tensor: ensembled labels
        """
        boxes = cat(boxes, dim=0)
        probs = cat(probs, dim=0)
        labels = cat(labels, dim=0)
        weights = cat(weights, dim=0)

        _, idx = probs.sort(descending=True)
        idx = idx[:self.parameters["ensemble_topk"]]
        boxes = boxes[idx]
        probs = probs[idx]
        labels = labels[idx]
        weights = weights[idx]

        # mean of the overlap map is used as a (cheap) proxy for the
        # expected number of predictions per object
        n_exp_preds = self.overlap_map_mean.expand(len(boxes)).to(boxes)
        # Bug fix: the ensemble NMS previously used "model_iou", which made
        # the "ensemble_iou" parameter (default and swept value) a no-op.
        boxes, probs, labels = self.parameters["ensemble_nms_fn"](
            boxes, probs, labels,
            weights=weights,
            iou_thresh=self.parameters["ensemble_iou"],
            n_exp_preds=n_exp_preds,
            score_thresh=self.parameters["ensemble_score_thresh"],
        )
        return boxes.cpu(), probs.cpu(), labels.cpu()

    @torch.no_grad()
    def get_case_result(self,
                        restore: bool = False,
                        names: Optional[Sequence[Hashable]] = None,
                        ) -> Dict[str, Tensor]:
        """
        Process all the batches and models and create the final prediction.
        Reduces the cache first (see :meth:`reduce_cache`).

        Args:
            restore: restore prediction in the original image space
            names: name of the models to use. By default all models are used.

        Returns:
            Dict: final result
                `pred_boxes`: predicted box locations
                    [N, dims * 2] (x1, y1, x2, y2, (z1, z2))
                `pred_scores`: predicted probability per box [N]
                `pred_labels`: predicted label per box [N]
                `restore`: indicate whether predictions were restored in
                    original image space
                `original_size_of_raw_data`: image shape before preprocessing
                `itk_origin`: itk origin of image before preprocessing
                `itk_spacing`: itk spacing of image before preprocessing
                `itk_direction`: itk direction of image before preprocessing
        """
        self.reduce_cache()
        return super().get_case_result(restore=restore, names=names)

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save case result as pickle file. Identifier of ensembler will
        be added to the name. Before saving the state, the cache will
        be reduced to a predefined number of predictions for memory
        and computational reasons

        Args:
            target_dir: folder to save result to
            name: name of case

        Notes:
            The device is not saved inside the checkpoint and everything
            will be loaded on the CPU.
        """
        self.reduce_cache()
        # skip BoxEnsembler.save_state which would persist the full
        # overlap map; only the scalar proxy is stored here
        return BaseEnsembler.save_state(
            self,
            target_dir=target_dir,
            name=name,
            reduced_cache=self.reduced_cache,
            score_key=self.score_key,
            label_key=self.label_key,
            box_key=self.box_key,
            data_key=self.data_key,
            overlap_map_mean=self.overlap_map_mean,
            **kwargs,
        )

    def reduce_cache(self):
        """
        Only keep the top `self.num_reduced_cache` scoring predictions
        per model for further evaluations. Idempotent.
        """
        if not self.reduced_cache:
            self.reduced_cache = True
            # we use the mean here to save time ...
            self.overlap_map_mean = self.overlap_map.avg()

            for model in self.model_results.keys():
                batch_idx = self.build_batch_indices(self.model_results[model]["scores"])

                boxes = cat(self.model_results[model]["boxes"])
                probs = cat(self.model_results[model]["scores"])
                labels = cat(self.model_results[model]["labels"])
                weights = cat(self.model_results[model]["weights"])

                if len(probs) > self.num_reduced_cache:
                    _, idx_sorted = probs.sort(descending=True)
                    idx_sorted = idx_sorted[:self.num_reduced_cache]
                    # Performance fix: membership tests against a tensor are
                    # O(k) each; a set gives O(1) lookups with the same result
                    keep = set(idx_sorted.tolist())
                    batch_idx_keep = [[b for b in bix if b in keep] for bix in batch_idx]
                    assert len(batch_idx_keep) == len(self.model_results[model]["scores"])

                    self.model_results[model]["boxes"] = [boxes[i] for i in batch_idx_keep]
                    self.model_results[model]["scores"] = [probs[i] for i in batch_idx_keep]
                    self.model_results[model]["labels"] = [labels[i] for i in batch_idx_keep]
                    self.model_results[model]["weights"] = [weights[i] for i in batch_idx_keep]

    @staticmethod
    def build_batch_indices(b: Sequence[Tensor]) -> List[List[int]]:
        """
        Build, for each cached batch entry, the list of global indices its
        elements occupy after concatenation of all entries.

        Args:
            b: cached per-batch tensors

        Returns:
            List[List[int]]: global index ranges, one list per batch entry
                (empty list for empty entries)
        """
        idx = []
        num_elem = 0
        for _b in b:
            if _b.numel() > 0:
                additional_elem = len(_b)
                idx.append(list(range(num_elem, num_elem + additional_elem)))
                num_elem += additional_elem
            else:
                idx.append([])
        return idx
class BoxEnsemblerSelective(BoxEnsembler):
    # Ensembler variant that postprocesses per model over the whole case
    # (instead of per tile) and therefore does not maintain an overlap map.

    def __init__(self,
                 properties: Dict[str, Any],
                 parameters: Dict[str, Any],
                 box_key: str = 'pred_boxes',
                 score_key: str = 'pred_scores',
                 label_key: str = 'pred_labels',
                 data_key: str = 'data',
                 device: Optional[Union[torch.device, str]] = None,
                 **kwargs,
                 ):
        """
        Ensemble bounding box detections from tta and multiple models
        This uses a different ensembling strategy which is faster and allows
        for model IoU optimization.

        Args:
            properties: properties of the patient/case (e.g. transpose axes)
            parameters: parameters for ensembling
            box_key: key where boxes are located inside prediction dict
            score_key: key where scores are located inside prediction dict
            label_key: key where labels are located inside prediction dict
            data_key: key where data is located inside batch dict
            device: device to use for internal computations
            kwargs: passed to super class
        """
        super().__init__(
            properties=properties,
            parameters=parameters,
            device=device,
            box_key=box_key,
            score_key=score_key,
            label_key=label_key,
            data_key=data_key,
            **kwargs,
        )
        # no overlap map needed: n_exp_preds is derived from the number of
        # models in `process_ensemble` instead
        self.overlap_map = None

    @classmethod
    def get_default_parameters(cls):
        """
        Generate default parameters for instantiation

        Returns:
            Dict:
                `model_iou`: IoU for model nms function
                `model_nms_fn`: function to use for model NMS
                `model_topk`: number of predictions with the highest
                    probability to keep
                `ensemble_iou`: IoU for ensembling the predictions of multiple
                    models
                `ensemble_nms_fn`: ensemble predictions from multiple
                    models
                `ensemble_topk`: number of predictions with the highest
                    probability to keep
                `remove_small_boxes`: minimum size of the box
                `ensemble_score_thresh`: minimum probability
        """
        return {
            # single model
            "model_iou": 0.1,
            "model_nms_fn": batched_weighted_nms_model,
            "model_score_thresh": 0.0,
            "model_topk": 1000,
            "model_detections_per_image": 100,
            # ensemble multiple models
            "ensemble_iou": 0.5,
            "ensemble_nms_fn": batched_wbc_ensemble,
            "ensemble_topk": 1000,
            "remove_small_boxes": 1e-2,
            "ensemble_score_thresh": 0.0,
        }

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any],
                                       Dict[str, Sequence[Any]]]:
        # Returns default parameters plus the sweep search space; `model_iou`
        # can be swept here because per-model NMS runs at case level.
        # iou_threshs = np.linspace(0.0, 0.8, 9)
        iou_threshs = np.linspace(0.0, 0.5, 6)
        iou_threshs[0] = 1e-5
        small_boxes_thresh = [1e-2] + np.linspace(2., 7., 6).tolist()

        param_sweep = {
            # single model
            "model_iou": iou_threshs,
            "model_nms_fn": [
                batched_weighted_nms_model,
                batched_nms_model,
            ],
            # ensemble multiple models
            "ensemble_iou": iou_threshs,
            "model_score_thresh": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "remove_small_boxes": small_boxes_thresh,
        }
        return cls.get_default_parameters(), param_sweep

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of bounding box predictions
        (the boxes are clipped to the case size in the ensembling step)

        Args:
            result: prediction from detector. Need to provide boxes, scores
                and class labels
                `self.box_key`: List[Tensor]: predicted boxes (relative
                    to patch coordinates)
                `self.score_key` List[Tensor]: score for each tensor
                `self.label_key`: List[Tensor] label prediction for each box
            batch: input batch for detector
                `tile_origin`: origin of crop with respect to actual data (
                    in case of padding)
                `crop`: Sequence[slice] original crop from data
        """
        # raw predictions are cached unfiltered; postprocessing happens
        # once per model in `process_model`
        boxes = [r.float().cpu() for r in result[self.box_key]]
        scores = [r.float().cpu() for r in result[self.score_key]]
        labels = [r.float().cpu() for r in result[self.label_key]]

        centers = [box_center(img_boxes) if img_boxes.numel() > 0 else Tensor([]).to(img_boxes)
                   for img_boxes in boxes]
        tile_origins = [to for to in zip(*batch["tile_origin"])]
        tile_size = batch[self.data_key].shape[2:]

        weights = [self._get_box_in_tile_weight(c, tile_size) for c in centers]
        weights = [w * self.model_weights[self.model_current] for w in weights]
        boxes = self._apply_offsets_to_boxes(boxes, tile_origins)

        self.model_results[self.model_current]["boxes"].extend(boxes)
        self.model_results[self.model_current]["scores"].extend(scores)
        self.model_results[self.model_current]["labels"].extend(labels)
        self.model_results[self.model_current]["weights"].extend(weights)

        # crops are not needed because no overlap map is maintained
        # self.model_results[self.model_current]["crops"].extend(
        #     list(zip(*batch["crop"])))

    @staticmethod
    def _get_box_in_tile_weight(box_centers: Tensor,
                                tile_size: Sequence[int],
                                ) -> Tensor:
        """
        Assign boxes near the corner a lower weight.
        The middle has a plateau with weight one, starting from patchsize / 2
        the weights decreases linearly until 0.5 is reached.

        NOTE(review): identical to BoxEnsemblerLW._get_box_in_tile_weight —
        could be shared instead of duplicated.

        Args:
            box_centers: center predicted box [N, dims]
            tile_size: size the of patch/tile

        Returns:
            Tensor: weight for each bounding box [N]
        """
        plateau_length = 0.5  # adjust width of plateau and min weight
        if box_centers.numel() > 0:
            tile_center = torch.tensor(tile_size).to(box_centers) / 2.  # [dims]
            max_dist = tile_center.norm(p=2)  # [1]
            boxes_dist = (box_centers - tile_center[None]).norm(p=2, dim=1)  # [N]
            weight = -(boxes_dist / max_dist - plateau_length).clamp_(min=0) + 1
            return weight
        else:
            return Tensor([]).to(box_centers)

    def process_model(self, name: Hashable) ->\
            Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Process the output of a single model on the whole scan
        topk candidates -> nms

        Args:
            name: name of model to process

        Returns:
            Tensor: processed boxes
            Tensor: processed probs
            Tensor: processed labels
            Tensor: processed weights
        """
        # collect predictions on whole case and apply postprocessing
        boxes = cat(self.model_results[name]["boxes"]).to(self.device)
        probs = cat(self.model_results[name]["scores"]).to(self.device)
        labels = cat(self.model_results[name]["labels"]).to(self.device)
        weights = cat(self.model_results[name]["weights"]).to(self.device)

        return self.postprocess_image(
            boxes=boxes,
            probs=probs,
            labels=labels,
            weights=weights,
            shape=tuple(self.properties["shape"]),
        )

    def process_ensemble(self, boxes: List[Tensor], probs: List[Tensor],
                         labels: List[Tensor], weights: List[Tensor],
                         ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Ensemble predictions from multiple models

        Args:
            boxes: predicted boxes List[[N, dims * 2]]
                (x1, y1, x2, y2, (z1, z2))
            probs: predicted probabilities List[[N]]
            labels: predicted label List[[N]]
            weights: additional weight List[[N]]

        Returns:
            Tensor: ensembled box predictions
            Tensor: ensembled probabilities
            Tensor: ensembled labels
        """
        num_models = len(boxes)

        boxes = cat(boxes, dim=0)
        probs = cat(probs, dim=0)
        labels = cat(labels, dim=0)
        weights = cat(weights, dim=0)

        _, idx = probs.sort(descending=True)
        idx = idx[:self.parameters["ensemble_topk"]]
        boxes = boxes[idx]
        probs = probs[idx]
        labels = labels[idx]
        weights = weights[idx]

        # each model is expected to contribute one prediction per object
        n_exp_preds = torch.tensor([num_models] * len(boxes)).to(boxes)
        boxes, probs, labels = self.parameters["ensemble_nms_fn"](
            boxes, probs, labels,
            weights=weights,
            iou_thresh=self.parameters["ensemble_iou"],
            n_exp_preds=n_exp_preds,
            score_thresh=self.parameters["ensemble_score_thresh"],
        )
        return boxes.cpu(), probs.cpu(), labels.cpu()

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save case result as pickle file. Identifier of ensembler will
        be added to the name.
        This version only saves the topk model predictions to speed
        up loading.

        Args:
            target_dir: folder to save result to
            name: name of case

        Notes:
            The device is not saved inside the checkpoint and everything
            will be loaded on the CPU.
        """
        for model in self.model_results.keys():
            boxes = cat(self.model_results[model]["boxes"])
            probs = cat(self.model_results[model]["scores"])
            labels = cat(self.model_results[model]["labels"])
            weights = cat(self.model_results[model]["weights"])

            if len(probs) > self.parameters["model_topk"]:
                _, idx_sorted = probs.sort(descending=True)
                idx_sorted = idx_sorted[:self.parameters["model_topk"]]

                # NOTE(review): this replaces the cached List[Tensor] by a
                # single Tensor; `process_model` still works because
                # `cat` handles both — confirm other consumers do too.
                self.model_results[model]["boxes"] = boxes[idx_sorted]
                self.model_results[model]["scores"] = probs[idx_sorted]
                self.model_results[model]["labels"] = labels[idx_sorted]
                self.model_results[model]["weights"] = weights[idx_sorted]

        return super().save_state(target_dir=target_dir, name=name, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment