Commit a28153dd authored by mibaumgartner's avatar mibaumgartner
Browse files

inference

parent c133c39c
from nndet.inference.ensembler import BaseEnsemblerType, BaseEnsembler, BoxEnsembler, SegmentationEnsembler
from nndet.inference.predictor import PredictorType, Predictor
from nndet.inference.sweeper import SweeperType, Sweeper, BoxSweeper
from nndet.inference.restore import restore_detection, restore_fmap
from nndet.inference.detection.wbc import batched_wbc, wbc
from nndet.inference.detection.model import batched_nms_model
from nndet.inference.detection.ensemble import batched_wbc_ensemble, batched_nms_ensemble, \
wbc_nms_no_label_ensemble
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple
from torch import Tensor
from nndet.detection.boxes import batched_nms, nms
from nndet.inference.detection import batched_wbc
def batched_nms_ensemble(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Class-aware NMS used as an ensembling strategy (thin wrapper around
    :func:`batched_nms` with the ensembler call signature)

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels
        weights: weight per box (unused here, present for signature
            compatibility)
        iou_thresh: IoU threshold for nms
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: kept boxes
        Tensor: kept scores
        Tensor: kept labels
    """
    kept = batched_nms(
        boxes=boxes,
        scores=scores,
        idxs=labels,
        iou_threshold=iou_thresh,
    )
    return boxes[kept], scores[kept], labels[kept]
def batched_wbc_ensemble(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    n_exp_preds: Tensor,
    score_thresh: float,
    *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Weighted box clustering used as an ensembling strategy (thin wrapper
    around :func:`batched_wbc` with the ensembler call signature)

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels
        weights: weight per box
        iou_thresh: IoU threshold used for clustering
        n_exp_preds: number of expected predictions per box
        score_thresh: minimum score of clustered predictions
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: clustered boxes
        Tensor: clustered scores
        Tensor: clustered labels
    """
    return batched_wbc(
        boxes,
        scores,
        labels,
        weights=weights,
        n_exp_preds=n_exp_preds,
        iou_thresh=iou_thresh,
        score_thresh=score_thresh,
    )
def wbc_nms_no_label_ensemble(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    n_exp_preds: Tensor,
    score_thresh: float,
    *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Weighted box clustering followed by class-agnostic NMS.
    This results in a single prediction per position regardless of the class.

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels
        weights: weight per box
        iou_thresh: IoU threshold used for clustering and the final NMS
        n_exp_preds: number of expected predictions per box
        score_thresh: minimum score of clustered predictions
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: boxes
        Tensor: scores
        Tensor: labels
    """
    clustered_boxes, clustered_scores, clustered_labels = batched_wbc(
        boxes, scores, labels, weights=weights,
        n_exp_preds=n_exp_preds,
        iou_thresh=iou_thresh,
        score_thresh=score_thresh,
    )
    # class-agnostic suppression: labels are intentionally ignored here
    survivors = nms(clustered_boxes, clustered_scores, iou_thresh)
    return (clustered_boxes[survivors],
            clustered_scores[survivors],
            clustered_labels[survivors])
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple
from torch import Tensor
import torch
from nndet.detection.boxes import batched_nms
def batched_nms_model(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Class-aware NMS for the predictions of a single model (thin wrapper
    around :func:`batched_nms` with the model call signature)

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels
        weights: weight per box (kept for surviving predictions)
        iou_thresh: IoU threshold for nms
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: sorted boxes
        Tensor: sorted scores (descending)
        Tensor: sorted labels
        Tensor: sorted weights
    """
    kept = batched_nms(
        boxes=boxes,
        scores=scores,
        idxs=labels,
        iou_threshold=iou_thresh,
    )
    return boxes[kept], scores[kept], labels[kept], weights[kept]
def batched_weighted_nms_model(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Weighted model NMS: each score is scaled by its box weight before
    suppression so higher-weighted boxes win ties; the original (unscaled)
    scores are returned and the weights are reset to one afterwards.

    Args:
        boxes: predicted boxes
        scores: predicted scores
        labels: predicted labels
        weights: weight per box (used for ranking during suppression)
        iou_thresh: IoU threshold for nms
        *args: kept for compatibility
        **kwargs: kept for compatibility

    Returns:
        Tensor: kept boxes
        Tensor: kept (unscaled) scores
        Tensor: kept labels
        Tensor: weights, reset to one for all kept boxes
    """
    weighted_scores = scores * weights
    kept = batched_nms(
        boxes=boxes,
        scores=weighted_scores,
        idxs=labels,
        iou_threshold=iou_thresh,
    )
    # weights were consumed by the ranking above; neutralize them downstream
    reset_weights = torch.ones_like(weights)
    return boxes[kept], scores[kept], labels[kept], reset_weights[kept]
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple
import torch
from torch import Tensor
from nndet.detection.boxes import batched_nms, nms
from nndet.inference.detection import batched_wbc
def batched_nms_model(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Class-aware NMS for predictions of a single model; per-box weights of
    the surviving predictions are kept.
    """
    kept = batched_nms(boxes=boxes, scores=scores, idxs=labels,
                       iou_threshold=iou_thresh)
    return boxes[kept], scores[kept], labels[kept], weights[kept]
def batched_nms_ensemble(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    *args, **kwargs,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Class-aware NMS used as an ensembling strategy; `weights` is accepted
    for signature compatibility but not used.
    """
    kept = batched_nms(boxes=boxes, scores=scores, idxs=labels,
                       iou_threshold=iou_thresh)
    return boxes[kept], scores[kept], labels[kept]
def batched_wbc_ensemble(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    n_exp_preds: Tensor,
    score_thresh: float,
    *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Weighted box clustering used as an ensembling strategy (thin wrapper
    around :func:`batched_wbc`).
    """
    return batched_wbc(
        boxes,
        scores,
        labels,
        weights=weights,
        n_exp_preds=n_exp_preds,
        iou_thresh=iou_thresh,
        score_thresh=score_thresh,
    )
def wbc_nms_no_label_ensemble(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    n_exp_preds: Tensor,
    score_thresh: float,
    *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Weighted box clustering followed by class-agnostic NMS, yielding at
    most one prediction per position regardless of the class.
    """
    clustered_boxes, clustered_scores, clustered_labels = batched_wbc(
        boxes, scores, labels, weights=weights,
        n_exp_preds=n_exp_preds,
        iou_thresh=iou_thresh,
        score_thresh=score_thresh,
    )
    survivors = nms(clustered_boxes, clustered_scores, iou_thresh)
    return (clustered_boxes[survivors],
            clustered_scores[survivors],
            clustered_labels[survivors])
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch import Tensor
from typing import Tuple
from torch._C import device
from nndet.detection.boxes import box_iou, box_area
__all__ = ["batched_wbc", "wbc"]
def batched_wbc(
    boxes: Tensor,
    scores: Tensor,
    labels: Tensor,
    weights: Tensor,
    iou_thresh: float,
    n_exp_preds: Tensor,
    score_thresh: float,
    use_area: bool = False,
    missing_weight: float = 1.,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Compute weighted box clustering per class

    Args:
        boxes: predicted boxes (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
        scores: predicted scores [N]
        labels: predicted labels [N]
        weights: weight for each box [N] (gaussian weighting of boxes near
            corners need to be included in this weight)
        iou_thresh: iou threshold used for clustering boxes
        n_exp_preds: number of expected predictions per box (computed as the
            mean number predictions inside the bounding box)
        score_thresh: minimum score of predictions after clustering
        use_area: assigns higher weights to larger boxes based on
            empirical observations indicating an increase in image
            evidence from larger areas.
            NOTE(review): defaults to False here but :func:`wbc` itself
            defaults to True — confirm which default is intended
        missing_weight: weight for score dampening when predictions are missing

    Returns:
        Tensor: clustered boxes
        Tensor: clustered scores
        Tensor: labels
    """
    clustered_boxes = []
    clustered_scores = []
    clustered_labels = []
    # cluster each class independently
    for label in labels.unique():
        _labels_mask = labels == label
        b, s = wbc(boxes[_labels_mask], scores[_labels_mask],
                   weights=weights[_labels_mask],
                   n_exp_preds=n_exp_preds[_labels_mask],
                   iou_thresh=iou_thresh, score_thresh=score_thresh,
                   use_area=use_area,
                   missing_weight=missing_weight,
                   )
        clustered_boxes.append(b)
        clustered_scores.append(s)
        # one label entry per surviving cluster
        clustered_labels.append(torch.empty_like(s).fill_(label))
    if clustered_boxes:
        return (torch.cat(clustered_boxes, dim=0),
                torch.cat(clustered_scores, dim=0),
                torch.cat(clustered_labels, dim=0))
    else:
        # no predictions at all: return empty tensors which follow the
        # device/dtype of the inputs (consistent with the empty case of
        # :func:`wbc`; previously these were always CPU float tensors)
        return (torch.tensor([]).view(-1, boxes.shape[1]).to(boxes),
                torch.tensor([]).view(-1).to(scores),
                torch.tensor([]).view(-1).to(scores))
def wbc(
    boxes: Tensor,
    scores: Tensor,
    weights: Tensor,
    n_exp_preds: Tensor,
    iou_thresh: float,
    score_thresh: float,
    use_area: bool = True,
    missing_weight: float = 1.,
) -> Tuple[Tensor, Tensor]:
    """
    Weighted box clustering: greedily groups boxes around the highest
    scoring remaining box and consolidates each group into one prediction.

    Args:
        boxes: tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores: score for each box [N]
        weights: additional weights for boxes [N]
        n_exp_preds: expected number of predictions per box
        iou_thresh: iou threshold for determining clusters of boxes which are
            combined
        score_thresh: minimum scores of boxes after consolidation
        use_area: assigns higher weights to larger boxes based on
            empirical observations indicating an increase in image
            evidence from larger areas.
            NOTE(review): defaults to True here while :func:`batched_wbc`
            defaults to False — confirm which default is intended
        missing_weight: weight for score dampening when predictions are missing

    Returns:
        Tensor: consolidated boxes
        Tensor: consolidated scores
    """
    # pairwise IoU matrix [N, N]; rows are reused for every cluster lookup
    ious = box_iou(boxes, boxes)
    if use_area:
        # scale weights by box area so larger boxes contribute more
        areas = box_area(boxes)
        weights = weights * areas
    # pool of candidate indices, processed in descending score order
    _, idx_pool = torch.sort(scores, descending=True)
    new_boxes, new_scores = [], []
    while idx_pool.nelement() > 0:
        # build cluster: everything in the pool that overlaps the current
        # highest scoring box by more than iou_thresh
        highest_scoring_id = idx_pool[0]
        matches = torch.where(ious[highest_scoring_id][idx_pool] > iou_thresh)[0].flatten()
        box_idx = idx_pool[matches]
        # compute new scores
        n_expected = n_exp_preds[box_idx].float().mean()
        new_box, new_score = compute_cluster_consolidation(
            boxes[box_idx], scores[box_idx],
            weights=weights[box_idx],
            ious=ious[highest_scoring_id][box_idx],
            n_expected=n_expected,
            n_found=len(box_idx),
            missing_weight=missing_weight,
        )
        # drop clusters whose consolidated score falls below the threshold
        if new_score > score_thresh:
            new_boxes.append(new_box)
            new_scores.append(new_score)
        # get all elements that were not matched and discard all others.
        non_matches = torch.where(ious[highest_scoring_id][idx_pool] <= iou_thresh)[0].flatten()
        idx_pool = idx_pool[non_matches]
    if new_boxes:
        # new_scores entries are [1]-shaped tensors, hence cat (not stack)
        return torch.stack(new_boxes, dim=0), torch.cat(new_scores, dim=0)
    else:
        # empty result mirrors the input device/dtype
        return torch.tensor([]).view(-1, boxes.shape[1]).to(boxes), torch.tensor([]).view(-1).to(scores)
def compute_cluster_consolidation(
    boxes: Tensor,
    scores: Tensor,
    weights: Tensor,
    ious: Tensor,
    n_expected: Tensor,
    n_found: int,
    missing_weight: float,
) -> Tuple[Tensor, Tensor]:
    """
    Fuse all predictions of a single cluster into one box and one score.

    Args:
        boxes: boxes of a single cluster (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
        scores: scores of a single cluster [N]
        weights: weights for boxes of a single cluster [N]
        ious: IoU of each box with regard to the highest scoring box of the
            cluster [N]
        n_expected: expected number of predictions for this cluster
        n_found: number of predictions actually found in this cluster
        missing_weight: dampening factor applied per missing prediction

    Returns:
        Tensor: consolidated box (x1, y1, x2, y2, (z1, z2)) [dims * 2]
        Tensor: consolidated score [1]
    """
    # each prediction contributes proportionally to its IoU and weight
    contribution = ious * weights
    weighted_scores = contribution * scores
    # [1]-shaped: number of expected detections that did not show up
    floor = torch.tensor([0.], device=n_expected.device)
    num_missing = torch.max(floor, (n_expected - n_found).float())
    # missing detections dampen the score via an average-contribution penalty
    normalizer = contribution.sum() + num_missing * contribution.mean() * missing_weight
    fused_score = weighted_scores.sum() / normalizer
    # coordinates are averaged, weighted by the per-box weighted scores
    fused_box = (boxes * weighted_scores.reshape(-1, 1)).sum(dim=0) / weighted_scores.sum()
    return fused_box, fused_score
def compute_cluster_consolidation2(
    boxes: Tensor,
    scores: Tensor,
    weights: Tensor,
    ious: Tensor,
    n_expected: Tensor,
    n_found: int,
    missing_weight: float,
) -> Tuple[Tensor, Tensor]:
    """
    Alternative cluster consolidation: keep only the `n_expected` best
    predictions of the cluster and penalize missing predictions linearly.

    Args:
        boxes: boxes of a single cluster (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
        scores: scores of a single cluster [N]
        weights: weights for boxes of a single cluster [N]
        ious: IoU of each box with regard to the highest scoring box of the
            cluster [N]
        n_expected: expected number of predictions for this cluster
        n_found: number of predictions actually found in this cluster
        missing_weight: dampening factor applied per missing prediction

    Returns:
        Tensor: consolidated box (x1, y1, x2, y2, (z1, z2)) [dims * 2]
        Tensor: consolidated score [1]
    """
    # rank members by their iou- and weight-scaled score; keep the best
    # n_expected ones (or all, if fewer are present)
    ranking = ious * weights * scores
    keep_num = min(len(scores), int(n_expected))
    kept_weighted_scores, kept_idx = ranking.topk(keep_num)
    kept_boxes = boxes[kept_idx]
    kept_scores = scores[kept_idx]
    floor = torch.tensor([0.], device=n_expected.device)
    num_missing = torch.max(floor, (n_expected - n_found).float())
    # linear penalty term for missing predictions
    fused_score = kept_scores.mean() * (1 - missing_weight * num_missing / n_expected)
    fused_box = (kept_boxes * kept_weighted_scores.reshape(-1, 1)).sum(dim=0) / kept_weighted_scores.sum()
    return fused_box, fused_score
from nndet.inference.ensembler.base import BaseEnsembler, BaseEnsemblerType, OverlapMap
from nndet.inference.ensembler.detection import BoxEnsembler
from nndet.inference.ensembler.segmentation import SegmentationEnsembler
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from os import PathLike
from pathlib import Path
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, TypeVar
import torch
from nndet.io.load import save_pickle
from nndet.utils.tensor import to_numpy
from nndet.utils.info import maybe_verbose_iterable
class BaseEnsembler(ABC):
    # file-name suffix used to identify checkpoints of this ensembler type;
    # subclasses override this
    ID = "abstract"

    def __init__(self,
                 properties: Dict[str, Any],
                 parameters: Dict[str, Any],
                 device: Optional[Union[torch.device, str]] = None,
                 **kwargs):
        """
        Base class to containerize and ensemble the predictions of a single case.
        Call :method:`process_batch` to add batched predictions of a case
        to the ensembler and :method:`add_model` to signal the next model
        if multiple models are used.

        Args:
            properties: properties of the patient/case (e.g. transpose axes)
            parameters: parameters for ensembling
            device: device to use for internal computations
            **kwargs: additional parameters for ensembling (merged into
                `parameters`)

        Notes:
            Call :method:`add_model` before adding predictions.
        """
        # name of the model currently receiving predictions
        self.model_current = None
        # per-model cached predictions: name -> defaultdict(list)
        self.model_results = {}
        # per-model ensembling weight: name -> float
        self.model_weights = {}
        self.properties = properties
        self.case_result: Optional[Dict] = None
        self.parameters = parameters
        self.parameters.update(kwargs)
        if device is None:
            self.device = torch.device("cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        elif isinstance(device, torch.device):
            self.device = device
        else:
            raise ValueError(f"Wrong type {type(device)} for device argument.")

    @classmethod
    def from_case(cls,
                  case: Dict,
                  properties: Optional[Dict] = None,
                  parameters: Optional[Dict] = None,
                  **kwargs,
                  ):
        """
        Primary way to instantiate this class. Automatically extracts all
        properties and uses a default set of parameters for ensembling.

        Args:
            case: case which is predicted (unused here; subclasses extract
                information from it)
            properties: Additional properties. Defaults to None.
            parameters: Additional parameters. Defaults to None.
        """
        return cls(properties=properties, parameters=parameters, **kwargs)

    def add_model(self,
                  name: Optional[str] = None,
                  model_weight: Optional[float] = None,
                  ) -> str:
        """
        Signal the ensembler to add a new model for internal processing.
        Subsequent predictions are cached under this model.

        Args:
            name: Name of the model. If None, the models are numbered.
                NOTE(review): the generated name is an int even though the
                signature promises a str — confirm callers accept this
            model_weight: Optional weight for this model. Defaults to None
                (which results in a weight of 1.0).

        Returns:
            str: the name under which the model was registered
        """
        if name is None:
            name = len(self.model_weights) + 1
        if name in self.model_results:
            raise ValueError(f"Invalid model name, model {name} is already present")
        if model_weight is None:
            model_weight = 1.0
        self.model_weights[name] = model_weight
        self.model_results[name] = defaultdict(list)
        # subsequent process_batch calls are attributed to this model
        self.model_current = name
        return name

    @abstractmethod
    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch

        Args:
            result: predictions to save and ensemble
            batch: input batch used for predictions (for additional meta data)

        Raises:
            NotImplementedError: Overwrite this function in subclasses for the
                specific use case.

        Warnings:
            Make sure to move cached values to the CPU after they have been
            processed.
        """
        raise NotImplementedError

    @abstractmethod
    @torch.no_grad()
    def get_case_result(self, restore: bool = False) -> Dict[str, torch.Tensor]:
        """
        Retrieve the results of a single case

        Args:
            restore: restores predictions in original image space

        Raises:
            NotImplementedError: Overwrite this function in subclasses for the
                specific use case.

        Returns:
            Dict[str, torch.Tensor]: the result of a single case
        """
        raise NotImplementedError

    def update_parameters(self, **parameters: Dict):
        """
        Update internal parameters used for ensembling the results

        Args:
            parameters: parameters to update
        """
        self.parameters.update(parameters)

    @classmethod
    @abstractmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any], Dict[str, Sequence[Any]]]:
        """
        Return a set of parameters which can be used to sweep ensembling
        parameters in a postprocessing step

        Returns:
            Dict[str, Any]: default state to start with
            Dict[str, Sequence[Any]]]: Defines the values to search for each
                parameter
        """
        raise NotImplementedError

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save the ensembler state to disk (as a torch `.pt` checkpoint).
        The identifier of the ensembler is appended to the name.

        Args:
            target_dir: folder to save result to
            name: name of case
            **kwargs: additional data to save
        """
        # every key stored here is restored verbatim as an attribute
        # by `_load`, so keys must match attribute names
        kwargs["properties"] = self.properties
        kwargs["parameters"] = self.parameters
        kwargs["model_current"] = self.model_current
        kwargs["model_results"] = self.model_results
        kwargs["model_weights"] = self.model_weights
        kwargs["case_result"] = self.case_result
        with open(Path(target_dir) / f"{name}_{self.ID}.pt", "wb") as f:
            torch.save(kwargs, f)

    def load_state(self, base_dir: PathLike, case_id: str) -> Dict:
        """
        Load a saved state from disk and apply it to this instance

        Args:
            base_dir: directory which contains the saved state
            case_id: identifier of the case to load

        Returns:
            Dict: the loaded checkpoint data
        """
        ckp = torch.load(str(Path(base_dir) / f"{case_id}_{self.ID}.pt"))
        self._load(ckp)
        return ckp

    def _load(self, state: Dict):
        # restore every checkpoint entry verbatim as an instance attribute
        for key, item in state.items():
            setattr(self, key, item)

    @classmethod
    def from_checkpoint(cls, base_dir: PathLike, case_id: str):
        """
        Create a new instance from a saved checkpoint
        """
        ckp = torch.load(str(Path(base_dir) / f"{case_id}_{cls.ID}.pt"))
        t = cls(
            properties=ckp["properties"],
            parameters=ckp["parameters"],
        )
        t._load(ckp)
        return t

    @classmethod
    def get_case_ids(cls, base_dir: PathLike):
        """
        List the case ids of all checkpoints of this ensembler type
        found in `base_dir`
        """
        return [c.stem.rsplit(f"_{cls.ID}", 1)[0]
                for c in Path(base_dir).glob(f"*_{cls.ID}.pt")]
class OverlapMap:
    def __init__(self, data_shape: Sequence[int]):
        """
        Handler for overlap map: tracks how many tiles/crops covered each
        voxel of a case

        Args:
            data_shape: spatial dimensions of data (
                no batch dim and no channel dim!)
        """
        self.overlap_map: torch.Tensor = \
            torch.zeros(*data_shape, requires_grad=False, dtype=torch.float)

    def add_overlap(self, crop: Sequence[slice]):
        """
        Increase values of :param:`self.overlap_map` inside of crop

        Args:
            crop: defines crop. Negative values are assumed to be outside
                of the data and thus discarded
        """
        # discard leading indexes which could be due to batches and channels
        if len(crop) > self.overlap_map.ndim:
            crop = crop[-self.overlap_map.ndim:]
        # clip crop to data shape
        slicer = []
        for data_shape, crop_dim in zip(tuple(self.overlap_map.shape), crop):
            start = max(0, crop_dim.start)
            stop = min(data_shape, crop_dim.stop)
            slicer.append(slice(start, stop, crop_dim.step))
        # index with a tuple: indexing a tensor with a *list* of slices is
        # deprecated/removed in newer PyTorch versions
        self.overlap_map[tuple(slicer)] += 1

    def mean_num_overlap_of_box(self, box: Sequence[int]) -> float:
        """
        Extract mean number of overlaps from a bounding box area

        Args:
            box: defines bounding box (x1, y1, x2, y2, (z1, z2))

        Returns:
            float: mean number of overlaps inside the box
        """
        slicer = [slice(int(box[0]), int(box[2])), slice(int(box[1]), int(box[3]))]
        if len(box) == 6:
            slicer.append(slice(int(box[4]), int(box[5])))
        return torch.mean(self.overlap_map[tuple(slicer)].float()).item()

    def mean_num_overlap_of_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Extract mean number of overlaps from a bounding box area

        Args:
            boxes: defines multiple bounding boxes (x1, y1, x2, y2, (z1, z2))
                [N, dim * 2]

        Returns:
            Tensor: mean number of overlaps per box [N]
        """
        return torch.tensor(
            [self.mean_num_overlap_of_box(box) for box in boxes]).to(
                dtype=torch.float, device=boxes.device)

    def avg(self) -> torch.Tensor:
        """
        Compute the median over all overlap counts (a robust average;
        kept as median to preserve existing behavior despite the name)
        """
        return self.overlap_map.float().median()

    def restore_mean(self, val):
        """
        Reset the overlap map to a constant value

        Args:
            val: value to fill the overlap map with
        """
        # previously this rebound `self.overlap_map` to a plain float,
        # which broke every subsequent tensor operation on it; keep it a
        # tensor filled with `val` instead
        self.overlap_map = torch.full_like(self.overlap_map, float(val))
def extract_results(source_dir: PathLike,
                    target_dir: PathLike,
                    ensembler_cls: Callable,
                    restore: bool,
                    **params,
                    ) -> None:
    """
    Compute case result from ensembler and save it

    Args:
        source_dir: directory which contains the saved predictions/state from
            the ensembler class
        target_dir: directory to save results
        ensembler_cls: ensembler class for prediction
        restore: if true, the results are converted into the original image
            space
        **params: overwrite ensembler parameters before computing the result
    """
    out_dir = Path(target_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for case_id in maybe_verbose_iterable(ensembler_cls.get_case_ids(source_dir)):
        case_ensembler = ensembler_cls.from_checkpoint(
            base_dir=source_dir, case_id=case_id)
        case_ensembler.update_parameters(**params)
        case_result = to_numpy(case_ensembler.get_case_result(restore=restore))
        save_pickle(case_result, out_dir / f"{case_id}_{ensembler_cls.ID}.pkl")
# type variable for signatures that accept any BaseEnsembler subclass
BaseEnsemblerType = TypeVar('BaseEnsemblerType', bound=BaseEnsembler)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Hashable, Union
import torch
import numpy as np
from scipy.stats import norm
from torch import Tensor
from loguru import logger
from nndet.inference.detection.model import batched_weighted_nms_model
from nndet.inference.detection import batched_nms_model, batched_nms_ensemble, \
batched_wbc_ensemble, wbc_nms_no_label_ensemble
from nndet.inference.ensembler.base import BaseEnsembler, OverlapMap
from nndet.inference.restore import restore_detection
from nndet.detection.boxes import box_center, clip_boxes_to_image, remove_small_boxes
from nndet.detection.boxes.merging import GreedyIoUBoxMerger, VoteLabelGreedyIoUBoxMerger
from nndet.utils.tensor import cat, to_device, to_dtype
class BoxEnsembler(BaseEnsembler):
ID = "boxes"
def __init__(self,
properties: Dict[str, Any],
parameters: Dict[str, Any],
box_key: str = 'pred_boxes',
score_key: str = 'pred_scores',
label_key: str = 'pred_labels',
data_key: str = 'data',
device: Optional[Union[torch.device, str]] = None,
**kwargs):
"""
Ensemble bounding box detections from tta and multiple models
Args:
properties: properties of the patient/case (e.g. tranpose axes)
parameters: parameters for ensembling
box_key: key where boxes are located inside prediction dict
score_key: key where scores are located inside prediction dict
label_key: key where labels are located inside prediction dict
data_key: key where data is located inside batch dict
device: device to use for internal computations
kwargs: passed to super class
"""
super().__init__(
properties=properties,
parameters=parameters,
device=device,
**kwargs,
)
# parameters to access information from predictions and batches
self.data_key = data_key
self.score_key = score_key
self.label_key = label_key
self.box_key = box_key
self.overlap_map = OverlapMap(tuple(self.properties["shape"]))
@classmethod
def from_case(cls,
case: Dict,
properties: Dict,
parameters: Optional[Dict] = None,
box_key: str = 'pred_boxes',
score_key: str = 'pred_scores',
label_key: str = 'pred_labels',
data_key: str = 'data',
device: Optional[Union[torch.device, str]] = None,
**kwargs,
):
"""
Primary way to instantiate this class. Automatically extracts all
properties and uses a default set of parameters for ensembling.
Args:
case: case which is predicted.
properties: Additional properties.
Required keys:
`transpose_backward`
`spacing_after_resampling`
`crop_bbox`
parameters: Additional parameters. Defaults to None.
box_key: key where boxes are located inside prediction dict
score_key: key where scores are located inside prediction dict
label_key: key where labels are located inside prediction dict
data_key: key where data is located inside batch dict
device: device to use for internal computations
"""
_parameters = cls.get_default_parameters()
_parameters.update(parameters)
_properties = {
"shape": case[data_key].shape[1:], # remove channel dim
"transpose_backward": properties["transpose_backward"],
"original_spacing": properties["original_spacing"],
"spacing_after_resampling": properties["spacing_after_resampling"],
"crop_bbox": properties["crop_bbox"],
"original_size_of_raw_data": properties["original_size_of_raw_data"],
"itk_origin": properties["itk_origin"],
"itk_spacing": properties["itk_spacing"],
"itk_direction": properties["itk_direction"],
}
return cls(
properties=_properties,
parameters=_parameters,
box_key=box_key,
score_key=score_key,
label_key=label_key,
data_key=data_key,
device=device,
**kwargs,
)
    @classmethod
    def get_default_parameters(cls):
        """
        Generate default parameters for instantiation

        Returns:
            Dict: default ensembling parameters
                `model_iou`: IoU threshold for the model NMS function
                `model_nms_fn`: function used for model NMS
                `model_score_thresh`: minimum probability for predictions
                    of a single model
                `model_topk`: number of predictions with the highest
                    probability to keep per image
                `model_detections_per_image`: upper bound of predictions
                    kept per image after model NMS
                `ensemble_iou`: IoU threshold for ensembling the predictions
                    of multiple models
                `ensemble_nms_fn`: function used to ensemble predictions
                    from multiple models
                `ensemble_topk`: number of predictions with the highest
                    probability to keep across the ensemble
                `remove_small_boxes`: minimum size of a box
                `ensemble_score_thresh`: minimum probability of the final
                    predictions
        """
        return {
            # single model
            "model_iou": 0.1,
            "model_nms_fn": batched_nms_model,
            "model_score_thresh": 0.0,
            "model_topk": 1000,
            "model_detections_per_image": 100,
            # ensemble multiple models
            "ensemble_iou": 0.5,
            "ensemble_nms_fn": batched_wbc_ensemble,
            "ensemble_topk": 1000,
            "remove_small_boxes": 1e-2,
            "ensemble_score_thresh": 0.0,
        }
def postprocess_image(self,
boxes: torch.Tensor,
probs: torch.Tensor,
labels: torch.Tensor,
weights: torch.Tensor,
shape: Optional[Tuple[int]] = None
) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
"""
Postprocessing of a single image
select topk predictions -> score threshold -> clipping -> \
remove small boxes -> nms
Args:
boxes: predicted deltas for proposals [N, dim * 2]
probs: predicted logits for boxes [N]
labels: predicted labels for boxes [N]
weights: weight for each box [N]
Returns:
torch.Tensor: postprocessed boxes
torch.Tensor: postprocessed probs
torch.Tensor: postprocessed labels
torch.Tensor: postprocessed weights
"""
p_sorted, idx_sorted = probs.sort(descending=True)
idx_sorted = idx_sorted[:self.parameters["model_topk"]]
p_sorted = p_sorted[:self.parameters["model_topk"]]
keep_idxs = p_sorted > self.parameters["model_score_thresh"]
idx_sorted = idx_sorted[keep_idxs]
b, p, l, w = boxes[idx_sorted], probs[idx_sorted], labels[idx_sorted], weights[idx_sorted]
b = clip_boxes_to_image(b, shape)
# After clipping we could have boxes with volume 0 which we definitely
# need to remove because of the IoU computation
keep = remove_small_boxes(
b, min_size=self.parameters["remove_small_boxes"])
b, p, l, w = b[keep], p[keep], l[keep], w[keep]
_boxes, _probs, _labels, _weights = self.parameters["model_nms_fn"](
boxes=b, scores=p, labels=l, weights=w,
iou_thresh=self.parameters["model_iou"],
)
# predictions are sorted
_boxes = _boxes[:self.parameters.get("model_detections_per_image", 1000)]
_probs = _probs[:self.parameters.get("model_detections_per_image", 1000)]
_labels = _labels[:self.parameters.get("model_detections_per_image", 1000)]
_weights = _weights[:self.parameters.get("model_detections_per_image", 1000)]
return _boxes, _probs, _labels, _weights
@staticmethod
def _apply_offsets_to_boxes(boxes: List[Tensor],
tile_offset: Sequence[Sequence[int]],
) -> List[Tensor]:
"""
Apply offset to bounding boxes to position them correctly inside
the whole case
Args:
boxes: predicted boxes [N, dims * 2]
[x1, y1, x2, y2, (z1, z2))
tile_offset: defines offset for each tile
Returns:
List[Tensor]: bounding boxes with respect to origin of whole case
"""
offset_boxes = []
for img_boxes, offset in zip(boxes, tile_offset):
if img_boxes.nelement() == 0:
offset_boxes.append(img_boxes)
continue
offset = Tensor(offset).to(img_boxes)
_boxes = img_boxes.clone()
_boxes[:, 0] += offset[0]
_boxes[:, 1] += offset[1]
_boxes[:, 2] += offset[0]
_boxes[:, 3] += offset[1]
if img_boxes.shape[1] == 6:
_boxes[:, 4] += offset[2]
_boxes[:, 5] += offset[2]
offset_boxes.append(_boxes)
return offset_boxes
def restore_prediction(self, boxes: Tensor):
"""
Restore predictions in the original image space
Args:
boxes: predicted boxes [N, dims * 2] (x1, y1, x2, y2, (z1, z2))
Returns:
Tensor: boxes in original image space [N, dims * 2]
(x1, y1, x2, y2, (z1, z2))
"""
_old_dtype = boxes.dtype
boxes_np = restore_detection(
boxes.detach().cpu().numpy(),
transpose_backward=self.properties["transpose_backward"],
original_spacing=self.properties["original_spacing"],
spacing_after_resampling=self.properties["spacing_after_resampling"],
crop_bbox=self.properties["crop_bbox"],
)
boxes = torch.from_numpy(boxes_np).to(dtype=_old_dtype)
return boxes
def save_state(self,
target_dir: Path,
name: str,
**kwargs,
):
"""
Save case result as pickle file. Identifier of ensembler will
be added to the name
Args:
target_dir: folder to save result to
name: name of case
Notes:
The device is not saved inside the checkpoint and everything
will be loaded on the CPU.
"""
super().save_state(
target_dir=target_dir,
name=name,
score_key=self.score_key,
label_key=self.label_key,
box_key=self.box_key,
data_key=self.data_key,
overlap_map=self.overlap_map,
**kwargs,
)
@classmethod
def from_checkpoint(cls, base_dir: PathLike, case_id: str, **kwargs):
ckp = torch.load(str(Path(base_dir) / f"{case_id}_{cls.ID}.pt"))
t = cls(
properties=ckp["properties"],
parameters=ckp["parameters"],
box_key=ckp["box_key"],
score_key=ckp["score_key"],
label_key=ckp["label_key"],
data_key=ckp["data_key"],
**kwargs
)
t._load(ckp)
return t
@classmethod
def sweep_parameters(cls) -> Tuple[Dict[str, Any],
Dict[str, Sequence[Any]]]:
# iou_threshs = np.linspace(0.0, 0.8, 9)
iou_threshs = np.linspace(0.0, 0.5, 6)
iou_threshs[0] = 1e-5
small_boxes_thresh = np.linspace(2., 7., 6)
param_sweep = {
# ensemble multiple models
"ensemble_iou": iou_threshs,
"model_score_thresh": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
# "remove_small_boxes": small_boxes_thresh,
}
return cls.get_default_parameters(), param_sweep
    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of bounding box predictions
        (the boxes are clipped to the case size in the ensembling step)

        Args:
            result: prediction from detector. Need to provide boxes, scores
                and class labels
                `self.box_key`: List[Tensor]: predicted boxes (relative
                    to patch coordinates)
                `self.score_key` List[Tensor]: score for each tensor
                `self.label_key`: List[Tensor] label prediction for each box
            batch: input batch for detector
                `tile_origin: origin of crop with respect to actual data (
                    in case of padding)
                `crop`: Sequence[slice] original crop from data

        Warnings:
            Make sure to move cached values to the CPU after they have been
            processed.
        """
        # batch["tile_origin"] is organised per dimension; regroup per sample
        tile_origins = [to for to in zip(*batch["tile_origin"])]
        tile_size = batch[self.data_key].shape[2:]

        boxes = []
        scores = []
        labels = []
        # postprocess every image of the batch individually (topk, score
        # threshold, clipping, small-box removal, model NMS); the dummy
        # all-ones weights are discarded afterwards
        for b, s, l in zip(result[self.box_key], result[self.score_key], result[self.label_key]):
            _boxes, _scores, _labels, _ = self.postprocess_image(
                boxes=b.float(),
                probs=s.float(),
                labels=l.float(),
                weights=torch.ones_like(s).float(),
                shape=tuple(tile_size),
            )
            # cache on CPU to keep GPU memory free between batches
            boxes.append(_boxes.cpu())
            scores.append(_scores.cpu())
            labels.append(_labels.cpu())

        # weights are computed from the tile-local box centers, i.e. BEFORE
        # the tile offsets are applied below
        centers = [box_center(img_boxes) if img_boxes.numel() > 0 else Tensor([]).to(img_boxes)
                   for img_boxes in boxes]
        weights = [self._get_box_in_tile_weight(c, tile_size) for c in centers]
        # scale by the weight assigned to the currently active model
        weights = [w * self.model_weights[self.model_current] for w in weights]
        # move boxes from tile coordinates into whole-case coordinates
        boxes = self._apply_offsets_to_boxes(boxes, tile_origins)

        self.model_results[self.model_current]["boxes"].extend(boxes)
        self.model_results[self.model_current]["scores"].extend(scores)
        self.model_results[self.model_current]["labels"].extend(labels)
        self.model_results[self.model_current]["weights"].extend(weights)

        # track which region of the case each tile covered; the overlap map
        # is later used to estimate the expected number of predictions per box
        crops_reshaped = list(zip(*batch["crop"]))
        self.model_results[self.model_current]["crops"].extend(crops_reshaped)
        for crop in crops_reshaped:
            self.overlap_map.add_overlap(crop)
@staticmethod
def _get_box_in_tile_weight(box_centers: Tensor,
tile_size: Sequence[int],
) -> Tensor:
"""
Assign boxes at the corners of tiles a lower weight (weight
is drawn form a scaled normal distribution)
Args:
box_centers: center predicted box [N, dims]
tile_size: size the of patch/tile
Returns:
Tensor: weight for each bounding box [N]
"""
if box_centers.numel() > 0:
all_weights = []
centers_np = box_centers.detach().cpu().numpy()
for center_np in centers_np:
weight = np.mean([
norm.pdf(bc, loc=ps, scale=ps * 0.8) * np.sqrt(2 * np.pi) * ps * 0.8
for bc, ps in zip(center_np, np.array(tile_size) / 2)])
all_weights.append([weight])
return torch.from_numpy(np.concatenate(all_weights)).to(box_centers)
else:
return Tensor([]).to(box_centers)
@torch.no_grad()
def get_case_result(self,
restore: bool = False,
names: Optional[Sequence[Hashable]] = None,
) -> Dict[str, Tensor]:
"""
Process all the batches and models and create the final prediction
Args:
restore: restore prediction in the original image space
names: name of the models to use. By default all models are used.
Returns:
Dict: final result
`pred_boxes`: predicted box locations
[N, dims * 2] (x1, y1, x2, y2, (z1, z2))
`pred_scores`: predicted probability per box [N]
`pred_labels`: predicted label per box [N]
`restore`: indicate whether predictions were restored in
original image space
`original_size_of_raw_data`: image shape befor preprocessing
`itk_origin`: itk origin of image before preprocessing
`itk_spacing`: itk spacing of image before preprocessing
`itk_direction`: itk direction of image before preprocessing
"""
if names is None:
names = list(self.model_results.keys())
boxes, probs, labels, weights = [], [], [], []
for name in names:
_boxes, _probs, _labels, _weights = self.process_model(name)
boxes.append(_boxes)
probs.append(_probs)
labels.append(_labels)
weights.append(_weights)
boxes, probs, labels = self.process_ensemble(
boxes=boxes, probs=probs, labels=labels,
weights=weights,
)
if restore:
boxes = self.restore_prediction(boxes)
return {
"pred_boxes": boxes,
"pred_scores": probs,
"pred_labels": labels,
"restore": restore,
"original_size_of_raw_data": self.properties["original_size_of_raw_data"],
"itk_origin": self.properties["itk_origin"],
"itk_spacing": self.properties["itk_spacing"],
"itk_direction": self.properties["itk_direction"],
}
def process_model(self, name: Hashable) ->\
Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Process the output of a single model on the whole scan
topk candidates -> nms
Args:
name: name of model to process
Returns:
Tensor: filtered boxes
Tensor: filtered probs
Tensor: filtered labels
idx: indices kept from original ordered data
"""
# concatenate batches
boxes = cat(self.model_results[name]["boxes"], dim=0)
probs = cat(self.model_results[name]["scores"], dim=0)
labels = cat(self.model_results[name]["labels"], dim=0)
weights = cat(self.model_results[name]["weights"], dim=0)
return boxes, probs, labels, weights
def process_ensemble(self, boxes: List[Tensor], probs: List[Tensor],
labels: List[Tensor], weights: List[Tensor],
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Ensemble predictions from multiple models
Args:
boxes: predicted boxes List[[N, dims * 2]]
(x1, y1, x2, y2, (z1, z2))
probs: predicted probabilities List[[N]]
labels: predicted label List[[N]]
weights: additional weight List[[N]]
Returns:
Tensor: ensembled box predictions
Tensor: ensembled probabilities
Tensor: ensembled labels
"""
boxes = cat(boxes, dim=0)
probs = cat(probs, dim=0)
labels = cat(labels, dim=0)
weights = cat(weights, dim=0)
_, idx = probs.sort(descending=True)
idx = idx[:self.parameters["ensemble_topk"]]
boxes = boxes[idx]
probs = probs[idx]
labels = labels[idx]
weights = weights[idx]
n_exp_preds = self.overlap_map.mean_num_overlap_of_boxes(boxes)
boxes, probs, labels = self.parameters["ensemble_nms_fn"](
boxes, probs, labels,
weights=weights,
iou_thresh=self.parameters["model_iou"],
n_exp_preds=n_exp_preds,
score_thresh=self.parameters["ensemble_score_thresh"],
)
return boxes.cpu(), probs.cpu(), labels.cpu()
class BoxEnsemblerLW(BoxEnsembler):
    """
    Uses different computation for box weight, much faster than box ensembler.
    """
    @staticmethod
    def _get_box_in_tile_weight(box_centers: Tensor,
                                tile_size: Sequence[int],
                                ) -> Tensor:
        """
        Assign boxes near the corner a lower weight.
        The middle has a plateau with weight one; starting from patchsize / 2
        the weight decreases linearly until 0.5 is reached.

        Args:
            box_centers: center predicted box [N, dims]
            tile_size: size of the patch/tile

        Returns:
            Tensor: weight for each bounding box [N]
        """
        plateau_length = 0.5  # adjust width of plateau and min weight
        if box_centers.numel() == 0:
            return Tensor([]).to(box_centers)
        tile_center = torch.tensor(tile_size).to(box_centers) / 2.  # [dims]
        max_dist = tile_center.norm(p=2)  # [1]
        # Euclidean distance of every box center to the tile center
        dist_to_center = (box_centers - tile_center[None]).norm(p=2, dim=1)  # [N]
        # 1.0 inside the plateau, linear falloff outside
        return 1 - (dist_to_center / max_dist - plateau_length).clamp(min=0)
class BoxEnsemblerFastest(BoxEnsemblerLW):
    """
    Uses the fastest but not necessarily most precise box ensembling strategy

    Only save top `num_reduced_cache` boxes for ensembling
    Uses a linear box weight
    Uses the mean over the whole overlap map. Depending on overlap
    and patch stride this is not correct.
    """
    def __init__(self, *args, **kwargs):
        """
        Args:
            *args: passed to super class
            **kwargs: passed to super class
        """
        super().__init__(*args, **kwargs)
        # True once the cached predictions were pruned to `num_reduced_cache`
        self.reduced_cache = False
        # maximum number of cached predictions per model after `reduce_cache`
        self.num_reduced_cache = 8000
        # scalar proxy for the overlap map (mean number of overlapping tiles)
        self.overlap_map_mean = None

    @classmethod
    def get_default_parameters(cls):
        """
        Generate default parameters for instantiation

        Returns:
            Dict:
                `model_iou`: IoU for model nms function
                `model_nms_fn`: function to use for model NMS
                `model_score_thresh`: minimum probability for a single model
                `model_topk`: number of predictions with the highest
                    probability to keep
                `model_detections_per_image`: maximum number of detections
                    per image
                `ensemble_iou`: IoU for ensembling the predictions of multiple
                    models
                `ensemble_nms_fn`: ensemble predictions from multiple
                    models
                `ensemble_topk`: number of predictions with the highest
                    probability to keep
                `remove_small_boxes`: minimum size of the box
                `ensemble_score_thresh`: minimum probability
        """
        return {
            # single model
            "model_iou": 0.1,
            "model_nms_fn": batched_nms_model,
            "model_score_thresh": 0.1,
            "model_topk": 1000,
            "model_detections_per_image": 1000,
            # ensemble multiple models
            "ensemble_iou": 0.5,
            "ensemble_nms_fn": batched_wbc_ensemble,
            "ensemble_topk": 1000,
            "remove_small_boxes": 1e-2,
            "ensemble_score_thresh": 0.0,
        }

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any],
                                       Dict[str, Sequence[Any]]]:
        """
        Define the parameter grid for the hyperparameter sweep.

        Returns:
            Dict[str, Any]: default parameters for instantiation
            Dict[str, Sequence[Any]]: candidate values per swept parameter
        """
        iou_threshs = np.linspace(0.0, 0.5, 6)
        # an IoU of exactly 0 would suppress any pair of touching boxes
        iou_threshs[0] = 1e-5
        small_boxes_thresh = [1e-2] + np.linspace(2., 7., 6).tolist()

        param_sweep = {
            # single model
            "model_iou": iou_threshs,
            # ensemble multiple models
            "ensemble_iou": iou_threshs,
            "remove_small_boxes": small_boxes_thresh,
            "model_score_thresh": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        }
        return cls.get_default_parameters(), param_sweep

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of bounding box predictions
        (the boxes are clipped to the case size in the ensembling step)

        Args:
            result: prediction from detector. Need to provide boxes, scores
                and class labels
                `self.box_key`: List[Tensor]: predicted boxes (relative
                    to patch coordinates)
                `self.score_key` List[Tensor]: score for each tensor
                `self.label_key`: List[Tensor] label prediction for each box
            batch: input batch for detector
                `tile_origin: origin of crop with respect to actual data (
                    in case of padding)
                `crop`: Sequence[slice] original crop from data
        """
        if self.reduced_cache:
            # new data arrived after `reduce_cache` was called -> the reduced
            # state is stale and needs to be rebuilt later
            logger.warning("Ensembler was already reduced, need to rerun reduce_cache "
                           "later and restore overlap map with proxy mean.")
            self.overlap_map.restore_mean(self.overlap_map_mean)
            self.reduced_cache = False

        # cache raw predictions in half precision on the CPU to save memory;
        # postprocessing happens once per model in `process_model`
        boxes = [r.half().cpu() for r in result[self.box_key]]
        scores = [r.half().cpu() for r in result[self.score_key]]
        labels = [r.half().cpu() for r in result[self.label_key]]

        # weights are computed from tile-local centers, BEFORE offsets are
        # applied below
        centers = [box_center(img_boxes) if img_boxes.numel() > 0 else Tensor([]).to(img_boxes)
                   for img_boxes in boxes]
        tile_origins = [to for to in zip(*batch["tile_origin"])]
        tile_size = batch[self.data_key].shape[2:]
        weights = [self._get_box_in_tile_weight(c, tile_size) for c in centers]
        weights = [w * self.model_weights[self.model_current] for w in weights]
        boxes = self._apply_offsets_to_boxes(boxes, tile_origins)

        self.model_results[self.model_current]["boxes"].extend(boxes)
        self.model_results[self.model_current]["scores"].extend(scores)
        self.model_results[self.model_current]["labels"].extend(labels)
        self.model_results[self.model_current]["weights"].extend(weights)

        crops_reshaped = list(zip(*batch["crop"]))
        self.model_results[self.model_current]["crops"].extend(crops_reshaped)
        for crop in crops_reshaped:
            self.overlap_map.add_overlap(crop)

    @staticmethod
    def _get_box_in_tile_weight(box_centers: Tensor,
                                tile_size: Sequence[int],
                                ) -> Tensor:
        """
        Assign boxes near the corner a lower weight.
        The middle has a plateau with weight one, starting from patchsize / 2
        the weights decreases linearly until 0.5 is reached.
        Half precision variant of :meth:`BoxEnsemblerLW._get_box_in_tile_weight`.

        Args:
            box_centers: center predicted box [N, dims]
            tile_size: size of the patch/tile

        Returns:
            Tensor: weight for each bounding box [N]
        """
        plateau_length = 0.5  # adjust width of plateau and min weight
        if box_centers.numel() > 0:
            tile_center = torch.tensor(tile_size).to(box_centers) / 2.  # [dims]
            max_dist = tile_center.norm(p=2)  # [1]
            boxes_dist = (box_centers - tile_center[None]).norm(p=2, dim=1)  # [N]
            # clamp_ is not implemented for half on all backends -> round trip
            # through float
            weight = -(boxes_dist / max_dist - plateau_length).float().clamp_(min=0).half() + 1
            return weight
        else:
            return Tensor([]).to(box_centers).half()

    def process_model(self,
                      name: Hashable,
                      ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Process the output of a single model on the whole scan
        topk candidates -> nms

        Args:
            name: name of model to process

        Returns:
            Tensor: processed boxes
            Tensor: processed probs
            Tensor: processed labels
            Tensor: processed weights
        """
        boxes = to_device(self.model_results[name]["boxes"], device=self.device)
        probs = to_device(self.model_results[name]["scores"], device=self.device)
        labels = to_device(self.model_results[name]["labels"], device=self.device)
        weights = to_device(self.model_results[name]["weights"], device=self.device)

        model_boxes = []
        model_probs = []
        model_labels = []
        model_weights = []
        # postprocess (topk, clipping, small-box removal, model NMS) per
        # cached image; empty images are skipped
        for b, p, l, w in zip(boxes, probs, labels, weights):
            if b.numel() > 0:
                _b, _p, _l, _w = self.postprocess_image(
                    boxes=b.float(),
                    probs=p.float(),
                    labels=l.float(),
                    weights=w.float(),
                    shape=tuple(self.properties["shape"]),
                )
                model_boxes.append(_b)
                model_probs.append(_p)
                model_labels.append(_l)
                model_weights.append(_w)
        return cat(model_boxes), cat(model_probs), cat(model_labels), cat(model_weights)

    def process_ensemble(self,
                         boxes: List[Tensor],
                         probs: List[Tensor],
                         labels: List[Tensor],
                         weights: List[Tensor],
                         ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Ensemble predictions from multiple models

        Args:
            boxes: predicted boxes List[[N, dims * 2]]
                (x1, y1, x2, y2, (z1, z2))
            probs: predicted probabilities List[[N]]
            labels: predicted label List[[N]]
            weights: additional weight List[[N]]

        Returns:
            Tensor: ensembled box predictions
            Tensor: ensembled probabilities
            Tensor: ensembled labels
        """
        boxes = cat(boxes, dim=0)
        probs = cat(probs, dim=0)
        labels = cat(labels, dim=0)
        weights = cat(weights, dim=0)

        # only keep the top-k most confident candidates for ensembling
        _, idx = probs.sort(descending=True)
        idx = idx[:self.parameters["ensemble_topk"]]
        boxes = boxes[idx]
        probs = probs[idx]
        labels = labels[idx]
        weights = weights[idx]

        # use the global mean of the overlap map as proxy for the expected
        # number of predictions per box (fast but approximate, see class doc)
        n_exp_preds = self.overlap_map_mean.expand(len(boxes)).to(boxes)
        boxes, probs, labels = self.parameters["ensemble_nms_fn"](
            boxes, probs, labels,
            weights=weights,
            # fixed: previously "model_iou" was passed here, so the swept
            # "ensemble_iou" parameter (see `sweep_parameters`) had no effect
            iou_thresh=self.parameters["ensemble_iou"],
            n_exp_preds=n_exp_preds,
            score_thresh=self.parameters["ensemble_score_thresh"],
        )
        return boxes.cpu(), probs.cpu(), labels.cpu()

    @torch.no_grad()
    def get_case_result(self,
                        restore: bool = False,
                        names: Optional[Sequence[Hashable]] = None,
                        ) -> Dict[str, Tensor]:
        """
        Process all the batches and models and create the final prediction

        Args:
            restore: restore prediction in the original image space
            names: name of the models to use. By default all models are used.

        Returns:
            Dict: final result
                `pred_boxes`: predicted box locations
                    [N, dims * 2] (x1, y1, x2, y2, (z1, z2))
                `pred_scores`: predicted probability per box [N]
                `pred_labels`: predicted label per box [N]
                `restore`: indicate whether predictions were restored in
                    original image space
                `original_size_of_raw_data`: image shape before preprocessing
                `itk_origin`: itk origin of image before preprocessing
                `itk_spacing`: itk spacing of image before preprocessing
                `itk_direction`: itk direction of image before preprocessing
        """
        # prune the cache (and replace the overlap map by its mean) before
        # running the normal ensembling pipeline of the super class
        self.reduce_cache()
        return super().get_case_result(restore=restore, names=names)

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save case result as pickle file. Identifier of ensembler will
        be added to the name. Before saving the state, the cache will
        be reduced to a predefined number of predictions to for memory
        and computational reasons

        Args:
            target_dir: folder to save result to
            name: name of case

        Notes:
            The device is not saved inside the checkpoint and everything
            will be loaded on the CPU.
        """
        self.reduce_cache()
        # deliberately skip BoxEnsembler.save_state: the full overlap map is
        # replaced by its scalar mean in this class
        return BaseEnsembler.save_state(
            self,
            target_dir=target_dir,
            name=name,
            reduced_cache=self.reduced_cache,
            score_key=self.score_key,
            label_key=self.label_key,
            box_key=self.box_key,
            data_key=self.data_key,
            overlap_map_mean=self.overlap_map_mean,
            **kwargs,
        )

    def reduce_cache(self):
        """
        Only save a subset of all boxes for further evaluations
        """
        if not self.reduced_cache:
            self.reduced_cache = True
            # we use the mean here to save time ...
            self.overlap_map_mean = self.overlap_map.avg()
            for model in self.model_results.keys():
                batch_idx = self.build_batch_indices(self.model_results[model]["scores"])

                boxes = cat(self.model_results[model]["boxes"])
                probs = cat(self.model_results[model]["scores"])
                labels = cat(self.model_results[model]["labels"])
                weights = cat(self.model_results[model]["weights"])

                if len(probs) > self.num_reduced_cache:
                    _, idx_sorted = probs.sort(descending=True)
                    idx_sorted = idx_sorted[:self.num_reduced_cache]
                    # set membership is O(1); `b in idx_sorted` on the tensor
                    # would scan all kept indices for every candidate
                    keep_set = set(idx_sorted.tolist())
                    batch_idx_keep = [[b for b in bix if b in keep_set] for bix in batch_idx]
                    assert len(batch_idx_keep) == len(self.model_results[model]["scores"])

                    # keep the per-batch list structure expected by
                    # `process_model`
                    self.model_results[model]["boxes"] = [boxes[i] for i in batch_idx_keep]
                    self.model_results[model]["scores"] = [probs[i] for i in batch_idx_keep]
                    self.model_results[model]["labels"] = [labels[i] for i in batch_idx_keep]
                    self.model_results[model]["weights"] = [weights[i] for i in batch_idx_keep]

    @staticmethod
    def build_batch_indices(b: Sequence[Tensor]) -> List[List[int]]:
        """
        Map every cached per-batch tensor to the index range it occupies in
        the concatenation of all cached tensors.

        Args:
            b: cached tensors, one entry per processed image

        Returns:
            List[List[int]]: global indices per cached tensor (empty list
                for empty tensors)
        """
        idx = []
        num_elem = 0
        for _b in b:
            if _b.numel() > 0:
                additional_elem = len(_b)
                idx.append(list(range(num_elem, num_elem + additional_elem)))
                num_elem += additional_elem
            else:
                idx.append([])
        return idx
class BoxEnsemblerSelective(BoxEnsembler):
    """
    Box ensembler which postprocesses each model once on the whole case
    (instead of per tile) and therefore allows sweeping the model IoU.
    """
    def __init__(self,
                 properties: Dict[str, Any],
                 parameters: Dict[str, Any],
                 box_key: str = 'pred_boxes',
                 score_key: str = 'pred_scores',
                 label_key: str = 'pred_labels',
                 data_key: str = 'data',
                 device: Optional[Union[torch.device, str]] = None,
                 **kwargs,
                 ):
        """
        Ensemble bounding box detections from tta and multiple models
        This uses a different ensembling strategy which is faster and allows
        for model IoU optimization.

        Args:
            properties: properties of the patient/case (e.g. transpose axes)
            parameters: parameters for ensembling
            box_key: key where boxes are located inside prediction dict
            score_key: key where scores are located inside prediction dict
            label_key: key where labels are located inside prediction dict
            data_key: key where data is located inside batch dict
            device: device to use for internal computations
            kwargs: passed to super class
        """
        super().__init__(
            properties=properties,
            parameters=parameters,
            device=device,
            box_key=box_key,
            score_key=score_key,
            label_key=label_key,
            data_key=data_key,
            **kwargs,
        )
        # this strategy does not track tile overlaps; the expected number of
        # predictions per box is derived from the number of models instead
        self.overlap_map = None

    @classmethod
    def get_default_parameters(cls):
        """
        Generate default parameters for instantiation

        Returns:
            Dict:
                `model_iou`: IoU for model nms function
                `model_nms_fn`: function to use for model NMS
                `model_score_thresh`: minimum probability per model
                `model_topk`: number of predictions with the highest
                    probability to keep
                `model_detections_per_image`: maximum number of detections
                    per image
                `ensemble_iou`: IoU for ensembling the predictions of multiple
                    models
                `ensemble_nms_fn`: ensemble predictions from multiple
                    models
                `ensemble_topk`: number of predictions with the highest
                    probability to keep
                `remove_small_boxes`: minimum size of the box
                `ensemble_score_thresh`: minimum probability
        """
        return {
            # single model
            "model_iou": 0.1,
            "model_nms_fn": batched_weighted_nms_model,
            "model_score_thresh": 0.0,
            "model_topk": 1000,
            "model_detections_per_image": 100,
            # ensemble multiple models
            "ensemble_iou": 0.5,
            "ensemble_nms_fn": batched_wbc_ensemble,
            "ensemble_topk": 1000,
            "remove_small_boxes": 1e-2,
            "ensemble_score_thresh": 0.0,
        }

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any],
                                       Dict[str, Sequence[Any]]]:
        """
        Define the parameter grid for the hyperparameter sweep.

        Returns:
            Dict[str, Any]: default parameters for instantiation
            Dict[str, Sequence[Any]]: candidate values per swept parameter
        """
        # iou_threshs = np.linspace(0.0, 0.8, 9)
        iou_threshs = np.linspace(0.0, 0.5, 6)
        # an IoU of exactly 0 would suppress any pair of touching boxes
        iou_threshs[0] = 1e-5
        small_boxes_thresh = [1e-2] + np.linspace(2., 7., 6).tolist()

        param_sweep = {
            # single model
            "model_iou": iou_threshs,
            "model_nms_fn": [
                batched_weighted_nms_model,
                batched_nms_model,
            ],
            # ensemble multiple models
            "ensemble_iou": iou_threshs,
            "model_score_thresh": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "remove_small_boxes": small_boxes_thresh,
        }
        return cls.get_default_parameters(), param_sweep

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of bounding box predictions
        (the boxes are clipped to the case size in the ensembling step)

        Args:
            result: prediction from detector. Need to provide boxes, scores
                and class labels
                `self.box_key`: List[Tensor]: predicted boxes (relative
                    to patch coordinates)
                `self.score_key` List[Tensor]: score for each tensor
                `self.label_key`: List[Tensor] label prediction for each box
            batch: input batch for detector
                `tile_origin: origin of crop with respect to actual data (
                    in case of padding)
                `crop`: Sequence[slice] original crop from data
        """
        # cache raw predictions on the CPU; postprocessing is deferred to
        # `process_model` where it runs once on the whole case
        boxes = [r.float().cpu() for r in result[self.box_key]]
        scores = [r.float().cpu() for r in result[self.score_key]]
        labels = [r.float().cpu() for r in result[self.label_key]]

        # weights are computed from tile-local centers, BEFORE the tile
        # offsets are applied below
        centers = [box_center(img_boxes) if img_boxes.numel() > 0 else Tensor([]).to(img_boxes)
                   for img_boxes in boxes]
        tile_origins = [to for to in zip(*batch["tile_origin"])]
        tile_size = batch[self.data_key].shape[2:]
        weights = [self._get_box_in_tile_weight(c, tile_size) for c in centers]
        weights = [w * self.model_weights[self.model_current] for w in weights]
        boxes = self._apply_offsets_to_boxes(boxes, tile_origins)

        self.model_results[self.model_current]["boxes"].extend(boxes)
        self.model_results[self.model_current]["scores"].extend(scores)
        self.model_results[self.model_current]["labels"].extend(labels)
        self.model_results[self.model_current]["weights"].extend(weights)
        # self.model_results[self.model_current]["crops"].extend(
        #     list(zip(*batch["crop"])))

    @staticmethod
    def _get_box_in_tile_weight(box_centers: Tensor,
                                tile_size: Sequence[int],
                                ) -> Tensor:
        """
        Assign boxes near the corner a lower weight.
        The middle has a plateau with weight one, starting from patchsize / 2
        the weights decreases linearly until 0.5 is reached.

        Args:
            box_centers: center predicted box [N, dims]
            tile_size: size of the patch/tile

        Returns:
            Tensor: weight for each bounding box [N]
        """
        plateau_length = 0.5  # adjust width of plateau and min weight
        if box_centers.numel() > 0:
            tile_center = torch.tensor(tile_size).to(box_centers) / 2.  # [dims]
            max_dist = tile_center.norm(p=2)  # [1]
            boxes_dist = (box_centers - tile_center[None]).norm(p=2, dim=1)  # [N]
            # 1.0 inside the plateau, linear falloff towards the corners
            weight = -(boxes_dist / max_dist - plateau_length).clamp_(min=0) + 1
            return weight
        else:
            return Tensor([]).to(box_centers)

    def process_model(self, name: Hashable) ->\
            Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Process the output of a single model on the whole scan
        topk candidates -> nms

        Args:
            name: name of model to process

        Returns:
            Tensor: processed boxes
            Tensor: processed probs
            Tensor: processed labels
            Tensor: processed weights
        """
        # collect predictions on whole case and apply postprocessing
        boxes = cat(self.model_results[name]["boxes"]).to(self.device)
        probs = cat(self.model_results[name]["scores"]).to(self.device)
        labels = cat(self.model_results[name]["labels"]).to(self.device)
        weights = cat(self.model_results[name]["weights"]).to(self.device)
        # topk -> score threshold -> clipping -> small-box removal -> model NMS
        return self.postprocess_image(
            boxes=boxes,
            probs=probs,
            labels=labels,
            weights=weights,
            shape=tuple(self.properties["shape"]),
        )

    def process_ensemble(self, boxes: List[Tensor], probs: List[Tensor],
                         labels: List[Tensor], weights: List[Tensor],
                         ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Ensemble predictions from multiple models

        Args:
            boxes: predicted boxes List[[N, dims * 2]]
                (x1, y1, x2, y2, (z1, z2))
            probs: predicted probabilities List[[N]]
            labels: predicted label List[[N]]
            weights: additional weight List[[N]]

        Returns:
            Tensor: ensembled box predictions
            Tensor: ensembled probabilities
            Tensor: ensembled labels
        """
        num_models = len(boxes)
        boxes = cat(boxes, dim=0)
        probs = cat(probs, dim=0)
        labels = cat(labels, dim=0)
        weights = cat(weights, dim=0)

        # only keep the top-k most confident candidates for ensembling
        _, idx = probs.sort(descending=True)
        idx = idx[:self.parameters["ensemble_topk"]]
        boxes = boxes[idx]
        probs = probs[idx]
        labels = labels[idx]
        weights = weights[idx]

        # each model contributes at most one prediction per object after its
        # own NMS -> expect `num_models` predictions per box
        n_exp_preds = torch.tensor([num_models] * len(boxes)).to(boxes)
        boxes, probs, labels = self.parameters["ensemble_nms_fn"](
            boxes, probs, labels,
            weights=weights,
            iou_thresh=self.parameters["ensemble_iou"],
            n_exp_preds=n_exp_preds,
            score_thresh=self.parameters["ensemble_score_thresh"],
        )
        return boxes.cpu(), probs.cpu(), labels.cpu()

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save case result as pickle file. Identifier of ensembler will
        be added to the name.
        This version only saves the topk model predictions to speed
        up loading.

        Args:
            target_dir: folder to save result to
            name: name of case

        Notes:
            The device is not saved inside the checkpoint and everything
            will be loaded on the CPU.

            NOTE(review): when topk truncation triggers, the cached lists in
            `model_results` are replaced by single concatenated tensors —
            presumably saving is terminal for this instance; verify callers
            do not call `process_model` afterwards.
        """
        for model in self.model_results.keys():
            boxes = cat(self.model_results[model]["boxes"])
            probs = cat(self.model_results[model]["scores"])
            labels = cat(self.model_results[model]["labels"])
            weights = cat(self.model_results[model]["weights"])

            # only keep the `model_topk` most confident predictions
            if len(probs) > self.parameters["model_topk"]:
                _, idx_sorted = probs.sort(descending=True)
                idx_sorted = idx_sorted[:self.parameters["model_topk"]]

                self.model_results[model]["boxes"] = boxes[idx_sorted]
                self.model_results[model]["scores"] = probs[idx_sorted]
                self.model_results[model]["labels"] = labels[idx_sorted]
                self.model_results[model]["weights"] = weights[idx_sorted]
        return super().save_state(target_dir=target_dir, name=name, **kwargs)
class BoxEnsemblerSelective2D(BoxEnsemblerSelective):
    """
    Box ensembler for 2d predictions

    Can be used to process 2d predictions of a 3d volume.
    """
    @classmethod
    def get_default_parameters(cls):
        """
        Generate default parameters for instantiation

        Returns:
            Dict:
                `model_iou`: IoU for model nms function
                `model_nms_fn`: function to use for model NMS
                `model_topk`: number of predictions with the highest
                    probability to keep
                `ensemble_iou`: IoU for ensembling the predictions of multiple
                    models
                `ensemble_nms_fn`: ensemble predictions from multiple
                    models
                `ensemble_topk`: number of predictions with the highest
                    probability to keep
                `remove_small_boxes`: minimum size of the box
                `ensemble_score_thresh`: minimum probability
                `track_iou`: IoU used to chain 2d boxes across slices
                `track_neighbor_slices`: number of neighboring slices to
                    consider while tracking
                `track_merger_cls`: class used to merge tracked 2d boxes
                `track_remove_small_boxes`: minimum size of the merged 3d box
        """
        params = super().get_default_parameters()
        # many more candidates are kept than in the 3d case because every
        # slice contributes its own 2d predictions
        params["model_topk"] = 10000
        params["model_detections_per_image"] = 4000
        params["model_score_thresh"] = 0.3
        params["ensemble_topk"] = 10000
        params["track_iou"] = 0.5
        params["track_neighbor_slices"] = 1
        params["track_merger_cls"] = VoteLabelGreedyIoUBoxMerger
        params["track_remove_small_boxes"] = 0
        return params

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any],
                                       Dict[str, Sequence[Any]]]:
        """
        Define the parameter grid for the hyperparameter sweep.

        Returns:
            Dict[str, Any]: default parameters for instantiation
            Dict[str, Sequence[Any]]: candidate values per swept parameter
        """
        iou_threshs = np.linspace(0.0, 0.5, 6)
        # an IoU of exactly 0 would suppress any pair of touching boxes
        iou_threshs[0] = 1e-5
        track_ious = np.linspace(0.3, 0.8, 6)

        param_sweep = {
            # single model
            "model_iou": iou_threshs,
            "model_nms_fn": [
                batched_weighted_nms_model,
                batched_nms_model,
            ],
            # ensemble multiple models
            "ensemble_iou": iou_threshs,
            "track_iou": track_ious,
            "track_neighbor_slices": [1, 2, 3, 4],
            "track_merger_cls": [
                GreedyIoUBoxMerger,
                VoteLabelGreedyIoUBoxMerger,
            ],
            "track_remove_small_boxes": [0, 1, 2, 3, 4],
            "model_score_thresh": [0.2, 0.3, 0.4, 0.5, 0.6],
        }
        return cls.get_default_parameters(), param_sweep

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of bounding box predictions
        (the boxes are clipped to the case size in the ensembling step)
        This already expands the box into the positive z direction.
        Expansion into the negative z dimension is done after the tracking
        step.

        Args:
            result: prediction from detector. Need to provide boxes, scores
                and class labels
                `self.box_key`: List[Tensor]: predicted boxes (relative
                    to patch coordinates)
                `self.score_key` List[Tensor]: score for each tensor
                `self.label_key`: List[Tensor] label prediction for each box
            batch: input batch for detector
                `tile_origin: origin of crop with respect to actual data (
                    in case of padding)
                `crop`: Sequence[slice] original crop from data
        """
        # start of each sample's crop in the first spatial dim
        # NOTE(review): assumes batch["crop"][0] holds the axial (z) slices
        # of all samples — confirm against the sweeper/predictor
        slice_idx = [c.start for c in batch["crop"][0]]
        boxes = [r.float().cpu() for r in result[self.box_key]]
        scores = [r.float().cpu() for r in result[self.score_key]]
        labels = [r.float().cpu() for r in result[self.label_key]]

        # process 2d boxes
        tile_size = batch[self.data_key].shape[2:]
        centers = [box_center(img_boxes) if img_boxes.numel() > 0 else Tensor([]).to(img_boxes)
                   for img_boxes in boxes]
        weights = [self._get_box_in_tile_weight(c, tile_size) for c in centers]
        weights = [w * self.model_weights[self.model_current] for w in weights]
        # drop the z component of the origin: the z position is encoded via
        # the slice index below
        tile_origins = [to[1:] for to in zip(*batch["tile_origin"])]
        boxes = self._apply_offsets_to_boxes(boxes, tile_origins)

        # convert to 3d boxes: project each 2d box (x1, y1, x2, y2) of slice z
        # into a pseudo 3d box (z, x1, z + 1, x2, y1, y2)
        boxes_3d = []
        for boxes_image, idx in zip(boxes, slice_idx):
            if boxes_image.numel() > 0:
                idx_tensor = torch.tensor([[float(idx), float(idx) + 1.]])
                idx_tensor_expanded = idx_tensor.to(boxes_image).expand(boxes_image.shape[0], -1)
                _boxes_3d = torch.stack([
                    idx_tensor_expanded[:, 0],
                    boxes_image[:, 0],
                    idx_tensor_expanded[:, 1],
                    boxes_image[:, 2],
                    boxes_image[:, 1],
                    boxes_image[:, 3],
                ], dim=1)
            else:
                # keep an empty tensor with the 3d box layout
                _boxes_3d = boxes_image.view(-1, 6)
            boxes_3d.append(_boxes_3d)

        self.model_results[self.model_current]["boxes"].extend(boxes_3d)
        self.model_results[self.model_current]["scores"].extend(scores)
        self.model_results[self.model_current]["labels"].extend(labels)
        self.model_results[self.model_current]["weights"].extend(weights)

    @torch.no_grad()
    def get_case_result(self,
                        restore: bool = False,
                        names: Optional[Sequence[Hashable]] = None,
                        ) -> Dict[str, Tensor]:
        """
        Process all the batches and models and create the final prediction

        Args:
            restore: restore prediction in the original image space
            names: name of the models to use. By default all models are used.

        Returns:
            Dict: final result
                `pred_boxes`: predicted box locations
                    [N, dims * 2] (x1, y1, x2, y2, (z1, z2))
                `pred_scores`: predicted probability per box [N]
                `pred_labels`: predicted label per box [N]
                `restore`: indicate whether predictions were restored in
                    original image space
                `original_size_of_raw_data`: image shape before preprocessing
                `itk_origin`: itk origin of image before preprocessing
                `itk_spacing`: itk spacing of image before preprocessing
                `itk_direction`: itk direction of image before preprocessing
        """
        if names is None:
            names = list(self.model_results.keys())

        # per-model postprocessing on the whole case
        boxes, probs, labels, weights = [], [], [], []
        for name in names:
            _boxes, _probs, _labels, _weights = self.process_model(name)
            boxes.append(_boxes)
            probs.append(_probs)
            labels.append(_labels)
            weights.append(_weights)

        boxes, probs, labels = self.process_ensemble(
            boxes=boxes, probs=probs, labels=labels,
            weights=weights,
        )
        # additional step compared to the super class: chain the per-slice
        # pseudo 3d boxes into real 3d boxes
        boxes, probs, labels = self.track_2d_to_3d(boxes, probs, labels)

        if restore:
            boxes = self.restore_prediction(boxes)

        return {
            "pred_boxes": boxes,
            "pred_scores": probs,
            "pred_labels": labels,
            "restore": restore,
            "original_size_of_raw_data": self.properties["original_size_of_raw_data"],
            "itk_origin": self.properties["itk_origin"],
            "itk_spacing": self.properties["itk_spacing"],
            "itk_direction": self.properties["itk_direction"],
        }

    def track_2d_to_3d(self, boxes: Tensor, probs: Tensor, labels: Tensor):
        """
        Converts the 2d tubes to 3d boxes

        Args:
            boxes: pseudo 3d boxes (each 2d box is projected into 3d space)
            probs: probability of each box
            labels: label of each box

        Returns:
            boxes: 3d boxes
            probs: probabilities
            labels: labels
        """
        # only track if the case actually has multiple slices
        if self.properties["shape"][0] > 1:
            # recover the original 2d layout (x1, y1, x2, y2) from the
            # pseudo 3d layout (z, x1, z + 1, x2, y1, y2)
            boxes_2d = boxes[:, [1, 4, 3, 5]]  # [N, 4]
            slices = torch.round(boxes[:, 0]).int()  # [N]
            merger = self.parameters["track_merger_cls"](
                boxes=boxes_2d,
                slices=slices,
                scores=probs,
                labels=labels,
                iou_th=self.parameters["track_iou"],
                neighbor_slices=self.parameters["track_neighbor_slices"],
            )
            boxes, probs, labels = merger.merge()
            # the upper bound is already correct, we need to fix the lower bound here
            boxes[:, 0] = boxes[:, 0] - 1
            keep = remove_small_boxes(
                boxes, min_size=self.parameters["track_remove_small_boxes"])
            boxes, probs, labels = boxes[keep], probs[keep], labels[keep]
        return boxes, probs, labels
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from os import PathLike
from pathlib import Path
from typing import Dict, Optional, Tuple, Sequence, Any
import torch
import numpy as np
from loguru import logger
from scipy.ndimage import gaussian_filter
from torch import Tensor
from nndet.inference.ensembler.base import BaseEnsembler
from nndet.inference.restore import restore_fmap
class SegmentationEnsembler(BaseEnsembler):
    """
    Ensemble semantic segmentation predictions over crops, TTA runs and
    multiple models: weighted prediction maps are accumulated inside the
    full case volume and normalized by the per-voxel sum of crop weights.
    """
    # suffix for checkpoint files written by `save_state` and read by
    # `from_checkpoint`
    ID = "seg"

    def __init__(self,
                 seg_key: str = 'pred_seg',
                 data_key: str = 'data',
                 **kwargs,
                 ):
        """
        Ensemble segmentation predictions from tta and model ensembling

        Args:
            seg_key: key where segmentation is located inside prediction dict
            data_key: key where data is located inside batch dict
            kwargs: passed to super class

        Notes:
            Further behavior is controlled via ``self.parameters``
            (populated e.g. by :meth:`from_case`):
            `use_gaussian`: apply gaussian weighting to individual crops
            `argmax`: apply argmax to the consolidated output
        """
        super().__init__(**kwargs)
        self.seg_key = seg_key
        self.data_key = data_key
        # accumulated weighted prediction maps [C, *shape]; allocated lazily
        # in `process_batch` once the channel count is known
        self.model_results: Optional[Tensor] = None
        # per-voxel sum of crop weights, used for normalization at the end
        self.overlap = torch.zeros(self.properties["shape"])
        # cache of per-crop weight matrices, keyed by crop size
        self.cache_crop_weight: Dict[Tuple, torch.Tensor] = {}

    @classmethod
    def from_case(cls,
                  case: Dict,
                  properties: Dict,
                  parameters: Optional[Dict] = None,
                  seg_key: str = 'pred_seg',
                  data_key: str = 'data',
                  **kwargs,
                  ):
        """
        Primary way to instantiate this class. Automatically extracts all
        properties and uses a default set of parameters for ensembling.

        Args:
            case: case which is predicted.
            properties: Additional properties.
                Required keys:
                    `transpose_backward`
                    `original_spacing`
                    `spacing_after_resampling`
                    `crop_bbox`
                    `size_after_cropping`
                    `original_size_of_raw_data`
                    `itk_origin`, `itk_spacing`, `itk_direction`
            parameters: Additional parameters (override the defaults
                `use_gaussian=True`, `argmax=True`). Defaults to None.
            seg_key: key where segmentation is located inside prediction dict
            data_key: key where data is located inside batch dict
        """
        parameters = parameters if parameters is not None else {}
        _parameters = {"use_gaussian": True, "argmax": True}
        _parameters.update(parameters)
        _properties = {
            "shape": case[data_key].shape[1:],  # remove channel dim
            "transpose_backward": properties["transpose_backward"],
            "original_spacing": properties["original_spacing"],
            "spacing_after_resampling": properties["spacing_after_resampling"],
            "crop_bbox": properties["crop_bbox"],
            "size_after_cropping": properties["size_after_cropping"],
            "original_size_before_cropping": properties["original_size_of_raw_data"],
            "itk_origin": properties["itk_origin"],
            "itk_spacing": properties["itk_spacing"],
            "itk_direction": properties["itk_direction"],
        }
        return cls(
            properties=_properties,
            parameters=_parameters,
            seg_key=seg_key,
            data_key=data_key,
            **kwargs,
        )

    def add_model(self,
                  name: Optional[str] = None,
                  model_weight: Optional[float] = None,
                  ) -> str:
        """
        This functions signales the ensembler to add a new model for internal
        processing

        Args:
            name: Name of the model. If None, uses the current model count
                as the name.
            model_weight: Optional weight for this model. Defaults to None
                (interpreted as 1.0).

        Returns:
            str: name under which the model was registered

        Raises:
            ValueError: if a model with the same name was already added
        """
        if name is None:
            # NOTE(review): this produces an int although the signature
            # promises str — confirm callers do not rely on str keys
            name = len(self.model_weights) + 1
        if name in self.model_weights:
            raise ValueError(f"Invalid model name, model {name} is already present")
        if model_weight is None:
            model_weight = 1.0
        self.model_weights[name] = model_weight
        # subsequent `process_batch` calls accumulate into this model
        self.model_current = name
        return name

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any], Dict[str, Sequence[Any]]]:
        """
        Parameter sweeping is not available for the segmentation ensembler

        Returns:
            Dict[str, Any]: empty default parameters
            Dict[str, Sequence[Any]]: empty sweep definition
        """
        return {}, {}

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Accumulate a single batch of segmentation predictions into the
        case volume

        Args:
            result: prediction from the network. Required keys:
                `self.seg_key`: predicted segmentation [N, C, dims]
            batch: input batch. Required keys:
                `crop`: Sequence[slice] location of each tile with regard
                    to the whole case (may exceed the case boundaries when
                    padding was used)
        """
        seg_batch = result[self.seg_key].cpu()
        crops = batch["crop"]
        # per-voxel crop weight (gaussian or uniform) scaled by model weight
        weight = self.get_weighting(tuple(seg_batch.shape[2:])).to(seg_batch)
        seg_batch = seg_batch * weight[None].to(seg_batch) * self.model_weights[self.model_current]
        if self.model_results is None:
            # lazy allocation: channel count is only known after the first batch
            self.model_results = torch.zeros(
                (int(seg_batch.shape[1]), *self.properties["shape"])).to(seg_batch)
        for seg, crop in zip(seg_batch, zip(*crops)):
            _weight = weight.clone()
            # trim predictions which lie outside of the case (padding regions)
            seg, case_crop = self.crop_to_case_boundaries(seg, crop)
            _weight, case_crop2 = self.crop_to_case_boundaries(_weight[None], crop)
            assert case_crop == case_crop2
            self.model_results[case_crop] += seg
            self.overlap[case_crop] += _weight[0]

    def crop_to_case_boundaries(self, seg: torch.Tensor, crop: Sequence[slice]):
        """
        Trim a tile prediction to the case boundaries (in case padding was
        used at the borders, the padded region needs to be removed)

        Args:
            seg: predicted segmentation of a single tile [C, dims]
            crop: location of the tile with regard to the whole case

        Returns:
            torch.Tensor: segmentation trimmed to the case boundaries
            Tuple: slicer (leading Ellipsis plus per-dim slices) to place
                the trimmed segmentation inside the case volume
        """
        if len(crop) > self.model_results.ndim - 1:
            # only the trailing spatial dims of the crop are relevant
            crop = crop[-(self.model_results.ndim - 1):]
        crop_slicer = []
        case_slicer = []
        for dim, c in enumerate(crop):
            # clamp target region to the case volume
            case_start = max(0, c.start)
            case_stop = min(self.model_results.shape[dim + 1], c.stop)
            diff_stop = c.stop - self.model_results.shape[dim + 1]
            # matching source region inside the tile
            crop_start = max(0, 0 - (c.start - 0))  # 0 added for completeness of pattern
            crop_stop = min(seg.shape[dim + 1], seg.shape[dim + 1] - diff_stop)
            crop_slicer.append(slice(crop_start, crop_stop, c.step))
            case_slicer.append(slice(case_start, case_stop, c.step))
        return seg[(..., *crop_slicer)], (..., *case_slicer)

    def get_weighting(self, crop_size: Tuple[int]) -> torch.Tensor:
        """
        Get matrix to weight predictions inside a single patch; results are
        cached per crop size

        Args:
            crop_size: spatial size of crop

        Returns:
            Tensor: weight matrix for the crop (gaussian centered in the
                patch if `use_gaussian`, otherwise all ones)
        """
        if crop_size not in self.cache_crop_weight:
            if self.parameters["use_gaussian"]:
                logger.info(f"Creating new gaussian weight matrix for crop size {crop_size}")
                tmp = np.zeros(crop_size)
                center_coords = [i // 2 for i in crop_size]
                sigmas = [i // 8 for i in crop_size]
                tmp[tuple(center_coords)] = 1
                tmp_smooth = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
                # normalize to max 1 and avoid exact zeros at the borders
                tmp_smooth = tmp_smooth / tmp_smooth.max() * 1
                weighting = tmp_smooth + 1e-8
                self.cache_crop_weight[crop_size] = torch.from_numpy(weighting).float()
            else:
                logger.info(f"Creating new weight matrix for crop size {crop_size}")
                self.cache_crop_weight[crop_size] = torch.ones(crop_size, dtype=torch.float)
        return self.cache_crop_weight[crop_size]

    def restore_prediction(self, logit_maps: Tensor) -> Tensor:
        """
        Restore prediction maps in the original image space

        Args:
            logit_maps: consolidated prediction maps [C, dims]

        Returns:
            Tensor: maps resampled into the original image space [C, dims]
        """
        _old_dtype = logit_maps.dtype
        logit_maps_np = restore_fmap(
            fmap=logit_maps.detach().cpu().numpy(),
            transpose_backward=self.properties["transpose_backward"],
            original_spacing=self.properties["original_spacing"],
            spacing_after_resampling=self.properties["spacing_after_resampling"],
            original_size_before_cropping=self.properties["original_size_before_cropping"],
            size_after_cropping=self.properties["size_after_cropping"],
            crop_bbox=self.properties["crop_bbox"],
            interpolation_order=1,
            interpolation_order_z=0,
            do_separate_z=None,
        )
        logit_maps = torch.from_numpy(logit_maps_np).to(dtype=_old_dtype)
        return logit_maps

    @torch.no_grad()
    def get_case_result(self,
                        restore: bool = False, **kwargs
                        ) -> Dict[str, Tensor]:
        """
        Get final result for case after ensembling and TTA

        Args:
            restore: restore prediction in the original image space
            kwargs: not used

        Returns:
            Dict: results
                `pred_seg`: [dims] if parameter `argmax` is set,
                    [C, dims] otherwise
                `restore`: indicate whether predictions were restored in
                    original image space
                `itk_origin`: itk origin of image before preprocessing
                `itk_spacing`: itk spacing of image before preprocessing
                `itk_direction`: itk direction of image before preprocessing
        """
        # normalize accumulated maps by the summed crop weights
        # NOTE(review): voxels never covered by any crop divide by zero —
        # presumably the tiling always covers the whole case; confirm
        result = self.model_results / self.overlap[None]
        if restore:
            result = self.restore_prediction(result)
        if self.parameters["argmax"]:
            result = result.argmax(dim=0).to(dtype=torch.uint8)
        self.case_result = result
        return {
            "pred_seg": self.case_result,
            "restore": restore,
            "itk_origin": self.properties["itk_origin"],
            "itk_spacing": self.properties["itk_spacing"],
            "itk_direction": self.properties["itk_direction"],
        }

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save case result as pickle file. Identifier of ensembler will
        be added to the name

        Args:
            target_dir: folder to save result to
            name: name of case
            kwargs: additional entries to persist alongside the state
        """
        super().save_state(
            target_dir=target_dir,
            name=name,
            seg_key=self.seg_key,
            data_key=self.data_key,
            case_crop_weight=self.cache_crop_weight,
            **kwargs,
        )

    @classmethod
    def from_checkpoint(cls, base_dir: PathLike, case_id: str, **kwargs):
        """
        Recreate an ensembler (including its accumulated state) from a
        checkpoint written by :meth:`save_state`

        Args:
            base_dir: directory which contains the checkpoint
            case_id: case identifier; the file `{case_id}_{cls.ID}.pt`
                is loaded
            kwargs: not used
        """
        ckp = torch.load(str(Path(base_dir) / f"{case_id}_{cls.ID}.pt"))
        t = cls(
            properties=ckp["properties"],
            parameters=ckp["parameters"],
            seg_key=ckp["seg_key"],
            data_key=ckp["data_key"],
        )
        t._load(ckp)
        return t
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pathlib import Path
from typing import Sequence, List, Dict, Callable, Optional
import numpy as np
from loguru import logger
from nndet.utils.tensor import to_numpy
from nndet.io.load import load_pickle, save_pickle
from nndet.io.paths import Pathlike, get_case_id_from_path
from nndet.inference.loading import load_time_ensemble
def predict_dir(
        source_dir: Pathlike,
        target_dir: Pathlike,
        cfg: dict,
        plan: dict,
        source_models: Path,
        model_fn: Callable[[Path, dict, dict, int], Sequence[dict]] = load_time_ensemble,
        num_models: int = None,
        num_tta_transforms: int = None,
        restore: bool = False,
        case_ids: Optional[Sequence[str]] = None,
        save_state: bool = False,
        **kwargs
):
    """
    Predict all preprocessed(!) cases inside a directory

    Args:
        source_dir: directory where preprocessed cases are located
        target_dir: directory to save predictions to
        cfg: config
            `predictor`: define predictor to use
        plan: plan
        source_models: directory where models for prediction are located
        model_fn: function to load model from directory
        num_models: number of models to use for prediction; None = all
        num_tta_transforms: number of tta transforms to use for
            prediction; None = all
        restore: restore predictions in original image space
        case_ids: case ids to predict. If None the whole folder will be
            predicted
        save_state: If `true` the state of the ensembler is saved. If
            `false` only the final result is saved.
        kwargs: passed to :method:'get_predictor' method of module

    Returns:
        the predictor instance used for inference
    """
    logger.info("Running inference")
    source_dir = Path(source_dir)
    target_dir = Path(target_dir)

    loaded = model_fn(source_models, cfg, plan, num_models)
    predictor = loaded[0]["model"].get_predictor(
        plan=plan,
        models=[entry["model"] for entry in loaded],
        num_tta_transforms=num_tta_transforms,
        **kwargs,
    )

    if case_ids is None:
        # skip ground-truth archives when scanning the folder
        case_paths = [cp for cp in source_dir.glob('*.npz')
                      if "_gt.npz" not in str(cp)]
    else:
        case_paths = [source_dir / f"{cid}.npz" for cid in case_ids]
    logger.info(f"Found {len(case_paths)} files for inference.")

    num_cases = len(case_paths)
    for case_num, case_path in enumerate(case_paths, start=1):
        logger.info(f"Predicting case {case_num} of {num_cases}.")
        case_id = get_case_id_from_path(str(case_path), remove_modality=False)

        # preprocessed data is either compressed (.npz) or unpacked (.npy)
        if case_path.is_file():
            case_data = np.load(str(case_path), allow_pickle=True)['data']
        else:
            case_data = np.load(str(case_path)[:-4] + ".npy", allow_pickle=True)

        case_props = load_pickle(case_path.parent / f"{case_id}.pkl")
        case_props["transpose_backward"] = plan["transpose_backward"]

        if save_state:
            # the predictor persists the full ensembler state itself
            _ = predictor.predict_case({"data": case_data},
                                       case_props,
                                       save_dir=target_dir,
                                       case_id=case_id,
                                       restore=restore,
                                       )
        else:
            # only persist the final (numpy-converted) results
            prediction = predictor.predict_case({"data": case_data},
                                                case_props,
                                                save_dir=None,
                                                case_id=None,
                                                restore=restore,
                                                )
            for key, item in to_numpy(prediction).items():
                save_pickle(item, target_dir / f"{case_id}_{key}.pkl")
    return predictor
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from functools import partial
from pathlib import Path
from typing import Sequence, Optional
import torch
from loguru import logger
from nndet.ptmodule import MODULE_REGISTRY
from nndet.io.paths import Pathlike
def get_loader_fn(mode: str, **kwargs):
    """
    Resolve the checkpoint loading strategy for a given mode

    Args:
        mode: `best` selects the time ensemble of best checkpoints,
            `final` the final checkpoint, `latest` the latest checkpoint
        kwargs: pre-bound to the returned loader via `functools.partial`

    Returns:
        Callable: loader function with `kwargs` bound

    Raises:
        ValueError: if `mode` is not one of the supported values
    """
    # branches resolve lazily on purpose: only the selected loader name
    # is looked up
    if mode == "best":
        return partial(load_time_ensemble, **kwargs)
    if mode == "final":
        return partial(load_final_model, **kwargs)
    if mode == "latest":
        return partial(load_final_model, identifier="latest", **kwargs)
    raise ValueError(f"Unknown mode {mode}")
def get_latest_model(base_dir: Pathlike, fold: int = 0) -> Optional[Path]:
    """
    Get the latest training dir in a given base dir

    E.g. ../RetinaUNetV0/fold0__0, ../RetinaUNetV0/fold0__1
    -> would select fold0__1

    Args:
        base_dir: path to base dir
        fold: fold to look for

    Returns:
        Optional[Path]: If no model for the specified fold is found, this
            will return None
    """
    import re

    base_dir = Path(base_dir)
    # `fold1` must not match `fold10...`: require a non-digit (or end of
    # string) right after the fold number
    fold_pattern = re.compile(rf"fold{fold}(?!\d)")
    candidates = [p for p in base_dir.iterdir()
                  if p.is_dir() and fold_pattern.search(p.stem)]
    if not candidates:
        return None

    def _run_key(p: Path):
        # sort numerically by the trailing `__<n>` run counter so that e.g.
        # fold0__10 is newer than fold0__9 (plain string sorting is not)
        m = re.search(r"__(\d+)$", p.stem)
        return (int(m.group(1)) if m else -1, p.stem)

    return max(candidates, key=_run_key)
# TODO: update
def load_time_ensemble(
        source_models: Path,
        cfg: dict,
        plan: dict,
        num_models: int = None,
) -> Sequence[dict]:
    """
    Load time ensembled models

    Args:
        source_models: path to directory where models are saved
        cfg: config used for experiment
            `module`: name of module in MODULE_REGISTRY
            `model_cfg`: model config passed to the module
            `trainer_cfg`: trainer config passed to the module
        plan: plan used for training
        num_models: number of models to use; None uses all found
            checkpoints

    Returns:
        Sequence[dict]: loaded models
            `model`: loaded model (CPU, eval mode)
            `rank`: rank of model

    Raises:
        RuntimeError: if no `model_best*.ckpt` checkpoint is found
    """
    logger.info("Loading time ensemble")
    model_names = list(source_models.glob('model_best*.ckpt'))
    if not model_names:
        raise RuntimeError(f"Did not find any models in {source_models}")

    models = []
    for path in model_names:
        model = MODULE_REGISTRY[cfg["module"]](
            model_cfg=cfg["model_cfg"],
            trainer_cfg=cfg["trainer_cfg"],
            plan=plan,
        )
        state_dict = torch.load(path, map_location="cpu")["state_dict"]
        t = model.load_state_dict(state_dict)
        logger.info(f"Loaded {path} with {t}")
        model.float()
        model.eval()

        # checkpoints are expected to be named `model_best<rank>...`: the
        # rank is the character directly after the 10-character
        # `model_best` prefix. `path.name` replaces the previous
        # platform-dependent `str(path).rsplit(os.sep, 1)[-1]`.
        rank = int(path.name[10])
        models.append({"model": model.cpu(), "rank": rank})

    if num_models is not None:
        # NOTE(review): models follow glob order here; if a specific
        # subset is intended, sorting by `rank` first may be required
        models = models[:num_models]
    logger.info(f"Using {len(models)} models for inference.")
    return models
def load_final_model(
        source_models: Path,
        cfg: dict,
        plan: dict,
        num_models: int = 1,
        identifier: str = "last",
) -> Sequence[dict]:
    """
    Load final model from training

    Args:
        source_models: path to directory where models are saved
        cfg: config used for experiment
            `module`: name of module in MODULE_REGISTRY
            `model_cfg`: model config passed to the module
            `trainer_cfg`: trainer config passed to the module
        plan: plan used for training
        num_models: only a single model is supported
        identifier: looks for identifier inside of model name

    Returns:
        Sequence[dict]: loaded models
            `model`: loaded model (eval mode)
            `rank`: rank is always 0

    Raises:
        ValueError: if `num_models` != 1 or `identifier` does not match
            exactly one checkpoint
    """
    # validate inputs with explicit raises instead of `assert` so the
    # checks survive `python -O`
    if num_models != 1:
        raise ValueError(
            f"load_final_model only supports num_models=1, found {num_models}")
    logger.info(f"Loading {identifier} model")
    model_names = [m for m in source_models.glob('*.ckpt')
                   if identifier in str(m.stem)]
    if len(model_names) != 1:
        raise ValueError(
            f"Found wrong number of models, {model_names} in "
            f"{source_models} with {identifier}")
    path = model_names[0]

    model = MODULE_REGISTRY[cfg["module"]](
        model_cfg=cfg["model_cfg"],
        trainer_cfg=cfg["trainer_cfg"],
        plan=plan,
    )
    state_dict = torch.load(path, map_location="cpu")["state_dict"]
    t = model.load_state_dict(state_dict)
    logger.info(f"Loaded {path} with {t}")
    model.float()
    model.eval()
    return [{"model": model, "rank": 0}]
def load_all_models(
        source_models: Path,
        cfg: dict,
        plan: dict,
        *args,
        **kwargs,
):
    """
    Load every checkpoint found inside a directory for ensembling

    Args:
        source_models: path to directory where models are saved
        cfg: config used for experiment
            `module`: name of module in MODULE_REGISTRY
            `model_cfg`: model config passed to the module
            `trainer_cfg`: trainer config passed to the module
        plan: plan used for training
        args: not used
        kwargs: not used

    Returns:
        Sequence[dict]: loaded models
            `model`: loaded model (CPU, eval mode)

    Raises:
        RuntimeError: if no checkpoint is found
    """
    checkpoint_paths = list(source_models.glob('*.ckpt'))
    if not checkpoint_paths:
        raise RuntimeError(f"Did not find any models in {source_models}")
    logger.info(f"Found {len(checkpoint_paths)} models to ensemble")

    def _load_single(ckpt_path):
        # build the module and restore its weights on CPU
        module = MODULE_REGISTRY[cfg["module"]](
            model_cfg=cfg["model_cfg"],
            trainer_cfg=cfg["trainer_cfg"],
            plan=plan,
        )
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        status = module.load_state_dict(state_dict)
        logger.info(f"Loaded {ckpt_path} with {status}")
        module.float()
        module.eval()
        return {"model": module.cpu()}

    return [_load_single(p) for p in checkpoint_paths]
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import typing
import itertools
import numpy as np
from loguru import logger
from skimage.measure import regionprops
import SimpleITK as sitk
def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
                            ) -> typing.List[tuple]:
    """
    Create slice tuples which crop a patch around every object in a mask

    Args
        mask: mask with objects numbered consecutively from 1 to n,
            background is 0.
        cshape: size of individual crops. Needs to be divisible by two,
            otherwise crops do not have the expected size.
            An int is broadcast to all dimensions.

    Returns
        list[tuple]: one tuple of slices per object, centered on the
            object centroid

    Raises
        TypeError: raised if mask and crops define different dimensionalities
        TypeError: raised if `cshape` is larger than mask

    See Also
        :func:`save_get_crop`

    Warnings
        The returned crops are not checked for image boundaries. Slices
        with negative indices and indices which extend over the mask
        boundaries are possible! Use `save_get_crop` to handle these cases.
    """
    if isinstance(cshape, int):
        cshape = (cshape,) * mask.ndim

    if mask.ndim != len(cshape):
        raise TypeError("Size of crops needs to be defined for "
                        "every dimension")
    if any(np.subtract(mask.shape, cshape) < 0):
        raise TypeError("Patches must be smaller than data.")

    if mask.max() == 0:
        # no objects present, nothing to crop around
        return []

    centroids = (props['centroid'] for props in regionprops(mask.astype(np.int32)))
    return [
        tuple(slice(int(coord) - (size // 2), int(coord) + (size // 2))
              for coord, size in zip(centroid, cshape))
        for centroid in centroids
    ]
def center_crop_object_seg(seg: np.ndarray, cshape: typing.Union[tuple, int],
                           **kwargs) -> typing.List[tuple]:
    """
    Create slice tuples which crop a patch around every object of a
    semantic segmentation. Objects are determined by region growing with
    connected threshold first.

    Args
        seg: semantic segmentation of objects.
        cshape: size of individual crops. Needs to be divisible by two,
            otherwise crops do not have the expected size.
            An int is broadcast to all dimensions.
        kwargs: additional keyword arguments passed to
            `center_crop_object_mask`

    Returns
        list[tuple]: one tuple of slices per object

    See Also
        :func:`save_get_crop`

    Warnings
        The returned crops are not checked for image boundaries. Slices
        with negative indices and indices which extend over the mask
        boundaries are possible! Use `save_get_crop` to handle these cases.
    """
    object_mask, _ = create_mask_from_seg(seg)
    return center_crop_object_mask(object_mask, cshape=cshape, **kwargs)
def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]:
    """
    Create a mask where objects are enumerated from 1, ..., n.
    Objects are determined by region growing with connected threshold
    (lower == upper == seed value, i.e. only voxels with exactly the
    seed's class value are grown).

    Args
        seg: semantic segmentation array

    Returns
        np.ndarray: mask with objects enumerated from 1 (int32)
        list: class value of each object (ascending object id order)
    """
    _seg = np.copy(seg).astype(np.int32)
    _seg_sitk = sitk.GetImageFromArray(_seg)
    _mask = np.zeros_like(seg).astype(np.int32)
    _obj_cls = []
    _obj = 1
    # region growing always runs on the unmodified image (`_seg_sitk` is
    # created once); each grow covers exactly one connected component,
    # which is then zeroed in `_seg` so the next seed picks a new component
    while _seg.max() > 0:
        # choose one seed in segmentation
        seed = np.transpose(np.nonzero(_seg))[0]
        # invert coordinates for sitk (sitk indexes axes in reversed order)
        seed_sitk = tuple(seed[:: -1].tolist())
        seed = tuple(seed)
        # region growing
        seg_con = sitk.ConnectedThreshold(_seg_sitk,
                                          seedList=[seed_sitk],
                                          lower=int(_seg[seed]),
                                          upper=int(_seg[seed]))
        seg_con = sitk.GetArrayFromImage(seg_con).astype(bool)
        # add object to mask
        _mask[seg_con] = _obj
        _obj_cls.append(_seg[seed])
        # remove object from segmentation
        _seg[seg_con] = 0
        _obj += 1
    # objects should never overlap
    assert _mask.max() < _obj
    return _mask, _obj_cls
def create_grid(cshape: typing.Union[typing.Sequence[int], int],
                dshape: typing.Sequence[int],
                overlap: typing.Union[typing.Sequence[int], int] = 0,
                mode='fixed',
                center_boarder: bool = False,
                **kwargs,
                ) -> typing.List[typing.Tuple[slice]]:
    """
    Create indices for a grid

    Args
        cshape: size of individual patches. An int is broadcast to all
            dimensions.
        dshape: shape of data
        overlap: overlap between patches. If `overlap` is an integer it is
            applied to all dimensions.
        mode: defines how borders should be handled, by default 'fixed'.
            `fixed` creates patches without special handling of borders,
            thus the last patch might exceed `dshape`
            `symmetric` moves patches such that the first and last patch
            are equally overlapping of dshape (when combined with padding,
            the last and first patch would have the same amount of padding)
        center_boarder: adds additional crops at the borders which have
            the border as their center
        kwargs: passed to the per-axis slicing function

    Returns
        typing.List[typing.Tuple[slice]]: slices to extract patches

    Raises
        TypeError: raised if `cshape` and `dshape` do not have the same length
        TypeError: raised if `overlap` and `dshape` do not have the same length
        TypeError: raised if `overlap` is larger than `cshape`

    Warnings
        The returned crops can exceed the image boundaries: slices with
        negative indices and indices which extend over the image boundary
        are possible. To correct for this, use `save_get_crop` which
        handles exceptions at borders.
    """
    _mode_fn = {
        "fixed": _fixed_slices,
        "symmetric": _symmetric_slices,
    }

    # 2d patches on a 3d volume: build the in-plane grid and prepend an
    # axis which selects single slices. Guard against an integer `cshape`
    # first — `len(int)` would raise a TypeError.
    if len(dshape) == 3 and not isinstance(cshape, int) and len(cshape) == 2:
        logger.info("Creating 2d grid.")
        slices_3d = dshape[0]
        dshape = dshape[1:]
    else:
        slices_3d = None

    # create tuples from shapes
    if isinstance(cshape, int):
        cshape = tuple([cshape] * len(dshape))
    if isinstance(overlap, int):
        overlap = tuple([overlap] * len(dshape))

    # check shapes
    if len(cshape) != len(dshape):
        raise TypeError(
            "cshape and dshape must be defined for same dimensionality.")
    if len(overlap) != len(dshape):
        raise TypeError(
            "overlap and dshape must be defined for same dimensionality.")
    if any(np.subtract(dshape, cshape) < 0):
        # oversized patches are allowed but flagged
        logger.warning(f"Found patch size which is bigger than data: data {dshape} patch {cshape}")
    if any(np.subtract(cshape, overlap) < 0):
        raise TypeError("Overlap must be smaller than size of patches.")

    grid_slices = [_mode_fn[mode](psize, dlim, ov, **kwargs)
                   for psize, dlim, ov in zip(cshape, dshape, overlap)]

    if center_boarder:
        # add one extra patch per axis at each border, centered on the border
        for idx, (psize, dlim, ov) in enumerate(zip(cshape, dshape, overlap)):
            lower_bound_start = int(-0.5 * psize)
            upper_bound_start = dlim - int(0.5 * psize)
            grid_slices[idx] = tuple([
                slice(lower_bound_start, lower_bound_start + psize),
                *grid_slices[idx],
                slice(upper_bound_start, upper_bound_start + psize),
            ])

    if slices_3d is not None:
        # leading axis enumerates the single slices of the 3d volume
        grid_slices = [tuple([slice(i, i + 1) for i in range(slices_3d)])] + grid_slices

    # cartesian product over all axes yields the final crops
    grid = list(itertools.product(*grid_slices))
    return grid
def _fixed_slices(psize: int, dlim: int, overlap: int, start: int = 0) -> typing.Tuple[slice]:
"""
Creates fixed slicing of a single axis. Only last patch exceeds dlim.
Args
psize: size of patch
dlim: size of data
overlap: overlap between patches
start: where to start patches, by default 0
Returns
typing.List[slice]: ordered slices for a single axis
"""
upper_limit = 0
lower_limit = start
idx = 0
crops = []
while upper_limit < dlim:
if idx != 0:
lower_limit = lower_limit - overlap
upper_limit = lower_limit + psize
crops.append(slice(lower_limit, upper_limit))
lower_limit = upper_limit
idx += 1
return tuple(crops)
def _symmetric_slices(psize: int, dlim: int, overlap: int) -> typing.Tuple[slice]:
    """
    Create a symmetric tiling of a single axis: the first and last patch
    exceed the data borders by the same amount.

    Args
        psize: size of patch
        dlim: size of data
        overlap: overlap between consecutive patches

    Returns
        typing.Tuple[slice]: ordered slices for a single axis
    """
    if psize >= dlim:
        # a single patch suffices: center it on the axis
        return _fixed_slices(psize, dlim, overlap, start=-(psize - dlim) // 2)
    # split the remainder of the tiling equally between both borders
    remainder = dlim % (psize - overlap)
    return _fixed_slices(psize, dlim, overlap, start=(remainder - psize) // 2)
def save_get_crop(data: np.ndarray,
                  crop: typing.Sequence[slice],
                  mode: str = "shift",
                  **kwargs,
                  ) -> typing.Tuple[np.ndarray,
                                    typing.Tuple[int],
                                    typing.Tuple[slice]]:
    """
    Safely extract a crop from data, correcting out-of-bounds coordinates

    Args
        data: array the patch is extracted from
        crop: coordinates of a single crop as slices (trailing dimensions)
        mode: border handling, by default "shift": the crop is moved
            inside the data. Any other value is forwarded to `np.pad`.
        kwargs: additional keyword arguments passed to `np.pad`

    Returns
        np.ndarray: crop from data
        Tuple[int]: origin offset of crop with regard to data origin (can
            be used to offset bounding boxes)
        Tuple[slice]: crop from data used to extract information

    See Also
        :func:`center_crop_object_mask`, :func:`center_crop_object_seg`

    Warnings
        This function only supports positive indexing. Negative indices
        are interpreted like they were outside the lower boundary!
    """
    if len(crop) > data.ndim:
        raise TypeError(
            "crop must have smaller or same dimensionality as data.")
    if mode == 'shift':
        # move slices inside the data if necessary
        return _shifted_crop(data, crop)
    # every other mode clips the crop and pads via np.pad
    return _padded_crop(data, crop, mode, **kwargs)
def _shifted_crop(data: np.ndarray,
crop: typing.Sequence[slice],
) -> typing.Tuple[np.ndarray,
typing.Tuple[int],
typing.Tuple[slice]]:
"""
Created shifted crops to handle borders
Args
data: crop is extracted from data
crop: defines boundaries of crops
Returns
List[np.ndarray]: list of crops
Tuple[int]: origin offset of crop with regard to data origin (can be
used to offset bounding boxes)
Tuple[slice]: crop from data used to extract information
Raises
TypeError: raised if patchsize is bigger than data
Warnings
This functions only supports positive indexing. Negative indices are
interpreted like they were outside the lower boundary!
"""
shifted_crop = []
dshape = tuple(data.shape)
# index from back, so batch and channel dimensions must not be defined
axis = data.ndim - len(crop)
for idx, crop_dim in enumerate(crop):
if crop_dim.start < 0:
# start is negative, thus it is subtracted from stop
new_slice = slice(0, crop_dim.stop - crop_dim.start, crop_dim.step)
if new_slice.stop > dshape[axis + idx]:
raise RuntimeError(
"Patch is bigger than entire data. shift "
"is not supported in this case.")
shifted_crop.append(new_slice)
elif crop_dim.stop > dshape[axis + idx]:
new_slice = \
slice(crop_dim.start - (crop_dim.stop - dshape[axis + idx]),
dshape[axis + idx], crop_dim.step)
if new_slice.start < 0:
raise RuntimeError(
"Patch is bigger than entire data. shift "
"is not supported in this case.")
shifted_crop.append(new_slice)
else:
shifted_crop.append(crop_dim)
origin = [int(x.start) for x in shifted_crop]
return data[tuple([..., *shifted_crop])], origin, shifted_crop
def _padded_crop(data: np.ndarray,
crop: typing.Sequence[slice],
mode: str,
**kwargs,
) -> typing.Tuple[np.ndarray,
typing.Tuple[int],
typing.Tuple[slice]]:
"""
Extract patch from data and pad accordingly
Args
data: crop is extracted from data
crop: defines boundaries of crops
mode: mode for padding. See `np.pad` for more details
kwargs: additional keyword arguments passed to :func:`np.pad`
Returns
typing.List[np.ndarray]: list of crops
Tuple[int]: origin offset of crop with regard to data origin (can be
used to offset bounding boxes)
Tuple[slice]: crop from data used to extract information
"""
clipped_crop = []
dshape = tuple(data.shape)
# index from back, so batch and channel dimensions must not be defined
axis = data.ndim - len(crop)
padding = [(0, 0)] * axis if axis > 0 else []
for idx, crop_dim in enumerate(crop):
lower_pad = 0
upper_pad = 0
lower_bound = crop_dim.start
upper_bound = crop_dim.stop
# handle lower bound
if lower_bound < 0:
lower_pad = -lower_bound
lower_bound = 0
# handle upper bound
if upper_bound > dshape[axis + idx]:
upper_pad = upper_bound - dshape[axis + idx]
upper_bound = dshape[axis + idx]
padding.append((lower_pad, upper_pad))
clipped_crop.append(slice(lower_bound, upper_bound, crop_dim.step))
origin = [int(x.start) for x in crop]
return (np.pad(data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs),
origin,
clipped_crop,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import time
import torch
import copy
import collections
import numpy as np
from loguru import logger
from typing import Hashable, List, Sequence, Dict, Union, Any, Optional, Callable, TypeVar
from pathlib import Path
from nndet.io.load import save_pickle
from nndet.models.abstract import AbstractModel
from nndet.io.transforms import NoOp
from nndet.inference.patching import save_get_crop, create_grid
from nndet.utils import to_device, maybe_verbose_iterable
from rising.transforms import AbstractTransform
from rising.loading import DataLoader
# public API of this module
__all__ = ["Predictor"]
# type alias: accepts either a torch.device or a device string like "cuda:0"
torch_device = Union[torch.device, str]
class Predictor:
    def __init__(self,
                 ensembler: Dict[str, Callable],
                 models: Sequence[AbstractModel],
                 crop_size: Sequence[int],
                 overlap: float = 0.5,
                 tile_keys: Sequence[str] = ('data',),
                 model_keys: Sequence[str] = ('data',),
                 tta_transforms: Sequence[AbstractTransform] = (NoOp(),),
                 tta_inverse_transforms: Sequence[AbstractTransform] = (NoOp(),),
                 pre_transform: AbstractTransform = None,
                 post_transform: AbstractTransform = None,
                 batch_size: int = 4,
                 model_weights: Sequence[float] = None,
                 device: torch_device = "cuda:0",
                 ensemble_on_device: bool = True,
                 ):
        """
        Predict entire cases with TTA and Model-Ensembling
        Workflow
        - Load whole patient
            -> create predictor from patient
        - tile patient
        * for each model:
            * for each batch (batches of tiles):
                * for each tta transform:
                    - pre transform
                    - tta transform
                    - post transform
                    - predict batch
                    - inverse tta transform
                    - forward predictions and batch to ensembler classes
        <- return patient result
        Args:
            ensembler: Callables to instantiate ensemblers from case and
                properties (one per result key, e.g. `boxes`, `seg`)
            models: models to ensemble
            crop_size: size of each crop (for most cases this should be
                the same as in training)
            overlap: overlap of crops (fraction of `crop_size`)
            tile_keys: keys which are tiled
            model_keys: these keys are passed as positional arguments to the
                model
            tta_transforms: tta transformations
            tta_inverse_transforms: inverse tta transformations (must have
                the same length as `tta_transforms`)
            pre_transform: transform which is performed before every tta
                transform
            post_transform: transform which is performed after every tta
                transform
            batch_size: batch size to use for prediction
            model_weights: additional weighting of individual models
            device: device used for prediction
            ensemble_on_device: The results will be passed to the ensembler
                class with the current device. The ensembler needs to make
                sure to avoid memory leaks!
        """
        self.ensemble_on_device = ensemble_on_device
        self.device = device
        self.ensembler_fns = ensembler
        # instantiated per case in :meth:`predict_case`
        self.ensembler = {}
        self.models = models
        # uniform weighting unless explicit weights are provided
        self.model_weights = [1.] * len(models) if model_weights is None else model_weights
        self.crop_size = crop_size
        self.overlap = overlap
        self.tile_keys = tile_keys
        self.model_keys = model_keys
        self.batch_size = batch_size
        if len(tta_transforms) != len(tta_inverse_transforms):
            raise ValueError("Every tta transform needs a reverse transform")
        self.tta_transforms = tta_transforms
        self.tta_inverse_transforms = tta_inverse_transforms
        self.post_transform = post_transform
        self.pre_transform = pre_transform
        # padding mode used when building the tiling grid
        self.grid_mode = 'symmetric'
        # crop extraction mode; 'shift' moves crops inside the case boundaries
        self.save_get_mode = 'shift'
    @classmethod
    def create(cls, *args, **kwargs):
        """
        Create predictor object with specific ensembler objects
        Raises:
            NotImplementedError: Needs to be overwritten in subclasses
        """
        raise NotImplementedError
    @classmethod
    def get_ensembler(cls, key: Hashable, dim: int) -> Callable:
        """
        Return ensembler class for specific keys
        Typically: `boxes`, `seg`, `instances`
        Args:
            key: Key to return
            dim: number of spatial dimensions the network expects
        Raises:
            NotImplementedError: Needs to be overwritten in subclasses
        Returns:
            Callable: Ensembler class
        """
        raise NotImplementedError
    def predict_case(self,
                     case: Dict,
                     properties: Optional[Dict] = None,
                     save_dir: Optional[Union[Path, str]] = None,
                     case_id: Optional[str] = None,
                     restore: bool = False,
                     ) -> dict:
        """
        Load and predict a single case.
        Args:
            case: data of a single case
            properties: additional properties of the case. E.g. to
                restore prediction in original image space
            save_dir: directory to save predictions; if None nothing is saved
            case_id: used for saving (file name prefix)
            restore: restore prediction in original image space
                ("revert" preprocessing)
        Returns:
            dict: result of each ensembler (converted to numpy)
        """
        tic = time.perf_counter()
        # fresh ensembler instances for every case
        for name, fn in self.ensembler_fns.items():
            self.ensembler[name] = fn(case, properties=properties)
        tiles = self.tile_case(case)
        self.predict_tiles(tiles)
        result = {key: value.get_case_result(restore=restore) for key, value in self.ensembler.items()}
        if save_dir is not None:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
            for ensembler in self.ensembler.values():
                ensembler.save_state(save_dir, name=case_id)
            save_pickle(properties, save_dir / f"{case_id}_properties.pkl")
        toc = time.perf_counter()
        logger.info(f"Prediction took {toc - tic} s")
        return result
    def tile_case(self, case: dict, update_remaining: bool = True) -> \
            Sequence[Dict[str, np.ndarray]]:
        """
        Create patches from whole patient for prediction
        Args:
            case: data of a single case
            update_remaining: properties from case which are not tiles
                are saved into all patches
        Returns:
            Sequence[Dict[str, np.ndarray]]: extracted crops from case
                and added new key:
                `tile_origin`: Sequence[int] offset of tile relative
                to case origin
        """
        # first tile key determines the spatial shape (channel first)
        dshape = case[self.tile_keys[0]].shape
        overlap = [int(c * self.overlap) for c in self.crop_size]
        crops = create_grid(
            cshape=self.crop_size,
            dshape=dshape[1:],
            overlap=overlap,
            mode=self.grid_mode,
        )
        tiles = []
        for crop in crops:
            try:
                # try selected extraction mode
                tile = {key: save_get_crop(case[key], crop, mode=self.save_get_mode)[0]
                        for key in self.tile_keys}
                _, tile["tile_origin"], tile["crop"] = save_get_crop(
                    case[self.tile_keys[0]], crop, mode=self.save_get_mode)
            except RuntimeError:
                # fallback to symmetric
                logger.warning("Path size is bigger than whole case, padding case to match patch size")
                tile = {key: save_get_crop(case[key], crop, mode="symmetric")[0]
                        for key in self.tile_keys}
                _, tile["tile_origin"], tile["crop"] = save_get_crop(
                    case[self.tile_keys[0]], crop, mode="symmetric")
            if update_remaining:
                # copy non-tiled entries (e.g. metadata) into every tile
                tile.update({key: item for key, item in case.items()
                             if key not in self.tile_keys})
            tiles.append(tile)
        return tiles
    @torch.no_grad()
    def predict_tiles(self, tiles: Sequence[Dict]) -> None:
        """
        Predict tiles of a single case with ensembling and tta. Results
        are saved inside ensemblers
        Args:
            tiles: tiles from single case
        """
        dataloader = DataLoader(tiles,
                                batch_size=self.batch_size,
                                shuffle=False,
                                collate_fn=slice_collate,
                                )
        for model_idx, (model, model_weight) in enumerate(
                zip(self.models, self.model_weights)):
            logger.info(f"Predicting model {model_idx + 1} of "
                        f"{len(self.models)} with weight {model_weight}.")
            model.to(device=self.device)
            model.eval()
            for t, (transform, inverse_transform) in enumerate(maybe_verbose_iterable(
                    list(zip(self.tta_transforms, self.tta_inverse_transforms)),
                    desc="Transform", position=0)):
                # register a new entry per (model, tta-transform) pair
                for ensembler in self.ensembler.values():
                    ensembler.add_model(name=f"model{model_idx}_t{t}", model_weight=model_weight)
                for batch_num, batch in enumerate(maybe_verbose_iterable(
                        dataloader, desc="Crop", position=1)):
                    self.predict_with_transformation(
                        model=model,
                        batch=batch,
                        batch_num=batch_num,
                        transform=transform,
                        inverse_transform=inverse_transform,
                    )
            # free GPU memory before loading the next model
            model.cpu()
            torch.cuda.empty_cache()
    def predict_with_transformation(self,
                                    model: AbstractModel,
                                    batch: Dict,
                                    batch_num: int,
                                    transform: Callable,
                                    inverse_transform: Callable,
                                    ):
        """
        Run prediction with the specified transformations
        Args:
            model: model to predict
            batch: input batch to model
            batch_num: batch index
            transform: transform to apply to batch.
            inverse_transform: inverse transform to apply to results
        """
        batch = to_device(batch, device=self.device)
        if self.pre_transform is not None:
            batch = self.pre_transform(**batch)
        transformed = transform(**batch)
        if self.post_transform is not None:
            transformed = self.post_transform(**transformed)
        inp = [transformed[key] for key in self.model_keys]
        # mixed precision inference
        with torch.cuda.amp.autocast():
            result = model.inference_step(*inp, batch_num=batch_num)
        # map predictions back into the untransformed tile space
        result = inverse_transform(**result)
        if not self.ensemble_on_device:
            result = to_device(result, device="cpu")
        for ensembler in self.ensembler.values():
            ensembler.process_batch(result=result, batch=batch)
def slice_collate(batch: List[Any]):
    """
    Collate function which additionally supports :class:`slice` objects
    Args:
        batch: items to collate
    Returns:
        Any: collated items
    """
    first = batch[0]
    if isinstance(first, slice):
        # slices cannot be collated into tensors -> keep them as a list
        return batch
    if isinstance(first, collections.abc.Mapping):
        return {key: slice_collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        return type(first)(*(slice_collate(group) for group in zip(*batch)))
    if isinstance(first, collections.abc.Sequence):
        return [slice_collate(group) for group in zip(*batch)]
    # everything else is handled by the default torch collate
    return torch.utils.data._utils.collate.default_collate(batch)
# bound TypeVar: accepts Predictor and any of its subclasses in annotations
PredictorType = TypeVar('PredictorType', bound=Predictor)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence, Tuple, Union, Optional
import numpy as np
from loguru import logger
from nndet.detection.boxes.utils import permute_boxes, expand_to_boxes
from nndet.preprocessing.resampling import resample_data_or_seg, get_do_separate_z, get_lowres_axis
def restore_detection(boxes: np.ndarray,
                      transpose_backward: Sequence[int],
                      original_spacing: Sequence[float],
                      spacing_after_resampling: Sequence[float],
                      crop_bbox: Sequence[Tuple[int, int]],
                      **kwargs,
                      ) -> np.ndarray:
    """
    Restore boxes from preprocessed space into original space
    Args:
        boxes: predicted boxes in preprocessing space,
            (x1, y1, x2, y2, (z1, z2))[N, dims * 2]
        transpose_backward: axis permutation reverting the forward transpose
        original_spacing: spacing of the original image
        spacing_after_resampling: spacing in the preprocessed space
            (forward transposed order)
        crop_bbox: bounding box crop in the original spacing [(min_0, max_0), ...]
        **kwargs: ignored
    Returns:
        np.ndarray: predicted bounding boxes in the original image space
    """
    # 1) revert the axis transposition applied during preprocessing
    transposed = permute_boxes(boxes, transpose_backward)
    # 2) rescale by the spacing ratio between preprocessed and original space
    scale = (np.asarray(spacing_after_resampling)[transpose_backward]
             / np.asarray(original_spacing))
    scaled = transposed * expand_to_boxes(scale[None])
    # 3) shift by the lower corner of the crop used during preprocessing
    shift = np.asarray([bounds[0] for bounds in crop_bbox])
    return scaled + expand_to_boxes(shift[None])
def restore_fmap(fmap: np.ndarray,
                 transpose_backward: Sequence[int],
                 original_spacing: Sequence[float],
                 spacing_after_resampling: Sequence[float],
                 original_size_before_cropping: Sequence[int],
                 size_after_cropping: Sequence[int],
                 crop_bbox: Optional[Sequence[Tuple[int, int]]] = None,
                 interpolation_order: int = 3,
                 interpolation_order_z: int = 0,
                 do_separate_z: bool = None,
                 ) -> np.ndarray:
    """
    Restore feature map from preprocessed space into original space
    Args:
        fmap: feature map to resample [C, dims], where C is the number of
            channels
        transpose_backward: backward transposing
        original_spacing: spacing of the original image
        spacing_after_resampling: spacing in the preprocessed space
            (forward transposed order)
        original_size_before_cropping: original image size before cropping
        size_after_cropping: image size after cropping
        crop_bbox: bounding box of crop
        interpolation_order: interpolation order for inplane axis
        interpolation_order_z: interpolation order for anisotropic axis
        do_separate_z: if None then we dynamically decide how to resample
            along z, if True/False then always/never resample along z
            separately. Do not touch unless you know what you are doing
    Returns:
        np.ndarray: resampled feature map [C, new_dims]
    """
    # revert the forward transpose (channel dim stays in front)
    fmap_transposed = np.transpose(fmap, [0] + [i + 1 for i in transpose_backward])
    original_spacing = np.asarray(original_spacing)
    spacing_after_resampling = np.asarray(spacing_after_resampling)
    resampled_spacing = spacing_after_resampling[transpose_backward]
    # BUGFIX: compare and resample the *transposed* map; `size_after_cropping`
    # refers to the back-transposed axis order, so using `fmap` here either
    # triggered spurious resampling or resampled the wrong axis order
    if np.any([i != j for i, j in zip(fmap_transposed.shape[1:], size_after_cropping)]):
        lowres_axis = _get_lowres_axes(original_spacing, resampled_spacing,
                                       do_separate_z=do_separate_z)
        logger.info(f"Resampling: do separate z: {do_separate_z}; lowres axis: {lowres_axis}")
        fmap_old_spacing = resample_data_or_seg(fmap_transposed, size_after_cropping, is_seg=False,
                                                axis=lowres_axis, order=interpolation_order,
                                                do_separate_z=do_separate_z, cval=0,
                                                order_z=interpolation_order_z)
    else:
        logger.info(f"Resampling: no resampling necessary")
        fmap_old_spacing = fmap_transposed
    if crop_bbox is not None:
        # paste the cropped region back into a zero canvas of the
        # original (pre-cropping) size
        crop_bbox = [list(cb) for cb in crop_bbox]
        tmp = np.zeros((fmap_old_spacing.shape[0], *original_size_before_cropping))
        for c in range(len(crop_bbox)):
            # clip upper bound to the canvas in case of rounding differences
            crop_bbox[c][1] = np.min(
                (crop_bbox[c][0] + fmap_old_spacing.shape[c + 1], original_size_before_cropping[c]))
        # BUGFIX: index with a tuple; list-based multidimensional indexing
        # is deprecated (interpreted as fancy indexing) in recent numpy
        _slices = tuple([...] + [slice(b[0], b[1]) for b in crop_bbox])
        tmp[_slices] = fmap_old_spacing
        fmap_original = tmp
    else:
        fmap_original = fmap_old_spacing
    return fmap_original
def _get_lowres_axes(original_spacing: Sequence[float],
                     resampled_spacing: Sequence[float],
                     do_separate_z: bool) -> Union[Sequence[int], None]:
    """
    Dynamically determine the low-resolution (anisotropic) axes
    Args:
        original_spacing: original spacing (not transposed!)
        resampled_spacing: resampled spacing (not transposed!)
        do_separate_z: force separate resampling along z; ``None``
            enables the automatic decision based on spacing anisotropy
    Returns:
        Union[Sequence[int], None]: lowres axes; ``None`` if no lowres
            axis is present
    """
    if do_separate_z is not None:
        # explicit user choice overrides the heuristic
        return get_lowres_axis(original_spacing) if do_separate_z else None
    if get_do_separate_z(original_spacing):
        # original spacing was anisotropic
        return get_lowres_axis(original_spacing)
    if get_do_separate_z(resampled_spacing):
        # resampled spacing was anisotropic
        return get_lowres_axis(resampled_spacing)
    # no anisotropy detected -> no separate z resampling
    return None
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABC, abstractmethod
from pathlib import Path
import time
from typing import Callable, Tuple, Dict, Sequence, Any, Optional, TypeVar
import numpy as np
from loguru import logger
from nndet.io.paths import Pathlike
from nndet.io.load import save_json
from nndet.utils.info import maybe_verbose_iterable
from nndet.utils import to_numpy
from nndet.evaluator.registry import BoxEvaluator
class Sweeper(ABC):
    def __init__(self,
                 classes: Sequence[str],
                 pred_dir: Pathlike,
                 gt_dir: Pathlike,
                 target_metric: str,
                 save_dir: Optional[Pathlike] = None,
                 ):
        """
        Sweep multiple parameters and compute evaluation metrics
        to determine the best set of parameters
        Args:
            classes: classes present in the dataset
            pred_dir: directory where predicted data is saved
            gt_dir: directory where the ground truth is saved
            target_metric: metric to optimize
            save_dir: directory to save sweep results; nothing is saved
                if None
        """
        self.classes = classes
        self.target_metric = target_metric
        self.pred_dir = Path(pred_dir)
        self.gt_dir = Path(gt_dir)
        # sweeping is run on CPU
        self.device = "cpu"
        if save_dir is None:
            self.save_dir = None
        else:
            self.save_dir = Path(save_dir)
            self.save_dir.mkdir(parents=True, exist_ok=True)
    @abstractmethod
    def run_postprocessing_sweep(self,
                                 restore: bool = True,
                                 ) -> Tuple[Dict, Dict]:
        """
        Run parameter sweeps to determine the best parameters
        according to the target metric
        Returns:
            Dict: determined parameters
            Dict: final results with parameters
        """
        raise NotImplementedError
class BoxSweeper(Sweeper):
    def __init__(self,
                 classes: Sequence[str],
                 pred_dir: Pathlike,
                 gt_dir: Pathlike,
                 target_metric: str,
                 ensembler_cls: Callable,
                 save_dir: Optional[Pathlike] = None,
                 ) -> None:
        """
        Run sweep over parameters and select the best
        Args:
            classes: classes present in dataset
            pred_dir: directory where predictions are saved
            gt_dir: directory where ground truth is saved
            target_metric: metric to optimize
            ensembler_cls: ensembler class used during prediction
            save_dir: Directory to save results. Defaults to None.
        """
        super().__init__(classes=classes,
                         pred_dir=pred_dir,
                         gt_dir=gt_dir,
                         target_metric=target_metric,
                         save_dir=save_dir,
                         )
        self.evaluator_cls = BoxEvaluator
        self.ensembler_cls = ensembler_cls
    def run_postprocessing_sweep(self):
        """
        Sequentially search for the best parameters: each parameter is
        optimized greedily while earlier choices are kept fixed
        Returns:
            Dict: final parameters to run inference on new cases
        """
        state, sweep_params = self.ensembler_cls.sweep_parameters()
        # BUGFIX: log the number of cases, not the raw list of case ids
        num_cases = len(self.ensembler_cls.get_case_ids(self.pred_dir))
        logger.info(f"Running parameter sweep on {num_cases} cases to optimize "
                    f"{self.target_metric} with initial state {state}.")
        best_score = float('-inf')
        for param_name, values in sweep_params.items():
            best_value, _best_score = self.run_parameter(
                values=values,
                param_name=param_name,
                state=state,
            )
            state[param_name] = best_value
            if _best_score < best_score:
                # the best score can never decrease during a sequential
                # sweep; a drop indicates inplace-modified cached results
                logger.error("ERROR: Something went wrong during sweeping. "
                             "Results were modified inplace! "
                             f"Previous: {best_score} now {_best_score}")
            best_score = _best_score
        logger.info(f"\n\n Determined {state} with best sweeping score {best_score} {self.target_metric}\n\n")
        return state
    def run_parameter(self,
                      values: Sequence[Any],
                      param_name: str,
                      state: Dict[str, Any],
                      ):
        """
        Evaluate candidate values of one parameter and select the best
        Args:
            values: values to evaluate
            param_name: name of parameter
            state: current values of all other parameters
        Returns:
            Any: best value for this parameter
            float: target-metric score achieved with the best value
        """
        cache = []
        overview = {}
        for value in values:
            logger.info(f"Running sweep {param_name}={value}")
            tic = time.perf_counter()
            metric_scores = self._evaluate_value(state=state, **{param_name: value})
            overview[f"{param_name}_{value}".replace(".", "_")] = {
                "state": str(state),
                "overwrite": {param_name: str(value)},
                "scores": str(metric_scores),
            }
            cache.append(metric_scores[self.target_metric])
            toc = time.perf_counter()
            logger.info(f"Sweep took {toc - tic} s")
        best_idx = np.argmax(cache)
        best_value = values[best_idx]
        best_score = cache[best_idx]
        if self.save_dir is not None:
            overview[f"best_{param_name}"] = {"value": str(best_value), "score": str(best_score)}
            save_json(overview, self.save_dir / f"sweep_{param_name}.json")
        return best_value, best_score
    def _evaluate_value(self,
                        state: Dict[str, Any],
                        **overwrite,
                        ):
        """
        Evaluate a single parameter configuration over all cases
        Args:
            state: base parameters for the ensembler
            **overwrite: parameter overwrites evaluated on top of `state`
        Returns:
            Dict: scalar metrics
        """
        evaluator = self.evaluator_cls.create(classes=self.classes,
                                              fast=True,
                                              verbose=False,
                                              save_dir=None,
                                              )
        for case_id in maybe_verbose_iterable(self.ensembler_cls.get_case_ids(self.pred_dir)):
            # restore ensembler state saved during prediction and apply
            # the parameters which should be evaluated
            ensembler = self.ensembler_cls.from_checkpoint(
                base_dir=self.pred_dir, case_id=case_id, device=self.device,
            )
            ensembler.update_parameters(**state)
            ensembler.update_parameters(**overwrite)
            pred = to_numpy(ensembler.get_case_result(restore=False))
            gt = np.load(str(self.gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
            evaluator.run_online_evaluation(
                pred_boxes=[pred["pred_boxes"]], pred_classes=[pred["pred_labels"]],
                pred_scores=[pred["pred_scores"]], gt_boxes=[gt["boxes"]],
                gt_classes=[gt["classes"]], gt_ignore=None,
            )
        metric_scores, _ = evaluator.finish_online_evaluation()
        return metric_scores
# bound TypeVar: accepts Sweeper and any of its subclasses in annotations
SweeperType = TypeVar('SweeperType', bound=Sweeper)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from loguru import logger
from typing import List, Tuple, Sequence
from nndet.io.transforms import Mirror, NoOp
from rising.transforms import AbstractTransform
def get_tta_transforms(num_tta_transforms: int, seg: bool = True) -> Tuple[
        List[AbstractTransform], List[AbstractTransform]]:
    """
    Get tta transformations
    Args:
        num_tta_transforms: number of tta transformations; 0: no tta, 4: augments
            all directions in 2D, 8: augments all directions in 3D
        seg: whether a segmentation prediction (`pred_seg`) is produced
            which needs to be mirrored back as well
    Returns:
        List[AbstractTransform]: transforms for TTA
        List[AbstractTransform]: inverted transformations for TTA
    """
    transforms = [NoOp()]
    inverse_transforms = [NoOp()]
    mirror_keys = ["data"]
    # BUGFIX: previously both branches yielded ["pred_seg"], making the `seg`
    # flag a no-op; without a segmentation output only boxes are mirrored back
    pred_mirror_keys = ["pred_seg"] if seg else []
    boxes_mirror_keys = ["pred_boxes"]
    if num_tta_transforms >= 4:
        logger.info("Adding 2D Mirror TTA for prediction.")
        # all 2D axis combinations; forward and inverse lists stay aligned
        for dims in [(0,), (1,), (0, 1)]:
            transforms.append(Mirror(keys=mirror_keys, dims=dims))
            inverse_transforms.append(Mirror(keys=pred_mirror_keys,
                                             box_keys=boxes_mirror_keys, dims=dims))
    if num_tta_transforms >= 8:
        logger.info("Adding 3D Mirror TTA for prediction.")
        # combinations which include the third spatial dimension
        for dims in [(2,), (0, 2), (1, 2), (0, 1, 2)]:
            transforms.append(Mirror(keys=mirror_keys, dims=dims))
            inverse_transforms.append(Mirror(keys=pred_mirror_keys,
                                             box_keys=boxes_mirror_keys, dims=dims))
    return transforms, inverse_transforms
class Inference2D(AbstractTransform):
    def __init__(self,
                 keys: Sequence[str],
                 ):
        """
        Helper transform to run inference with 2d models
        Args:
            keys: data keys to remove the dimension from for inference
        """
        super().__init__(grad=False)
        self.keys = keys
    def forward(self, **data) -> dict:
        """
        Removes the first spatial dimension (N, C, [removed], ax1, ax2)
        Args:
            **data: input batch
        Returns:
            dict: transformed batch
        """
        # drop the (singleton) leading spatial axis -> (N, C, ax1, ax2)
        data.update({key: data[key][:, :, 0] for key in self.keys})
        return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment