Commit 7246044d authored by mibaumgartner's avatar mibaumgartner
Browse files

Merge remote-tracking branch 'origin/master' into main

parents fcec502f 6f4c3333
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from os import PathLike
from pathlib import Path
from typing import Dict, Optional, Tuple, Sequence, Any
import torch
import numpy as np
from loguru import logger
from scipy.ndimage import gaussian_filter
from torch import Tensor
from nndet.inference.ensembler.base import BaseEnsembler
from nndet.inference.restore import restore_fmap
class SegmentationEnsembler(BaseEnsembler):
    """
    Ensembler which accumulates (optionally gaussian-) weighted segmentation
    logits of overlapping crops from multiple models and TTA transforms and
    normalizes them into a single per-case prediction.
    """
    # identifier used in checkpoint file names: f"{case_id}_{ID}.pt"
    ID = "seg"

    def __init__(self,
                 seg_key: str = 'pred_seg',
                 data_key: str = 'data',
                 **kwargs,
                 ):
        """
        Ensemble segmentation predictions from tta and model ensembling

        Args:
            seg_key: key where segmentation is located inside prediction dict
            data_key: key where data is located inside batch dict
            kwargs: passed to super class. The parameters consumed here
                (read from ``self.parameters``, see :meth:`from_case`):
                `use_gaussian`: apply gaussian weighting to individual crops
                `argmax`: apply argmax to the consolidated output
        """
        super().__init__(**kwargs)
        self.seg_key = seg_key
        self.data_key = data_key
        # weighted logit accumulator [C, *shape]; lazily allocated in
        # process_batch once the channel count of the prediction is known
        self.model_results: Optional[Tensor] = None
        # per-voxel sum of crop weights; used to normalize model_results
        self.overlap = torch.zeros(self.properties["shape"])
        # weight matrices cached per crop size (see get_weighting)
        self.cache_crop_weight: Dict[Tuple, torch.Tensor] = {}

    @classmethod
    def from_case(cls,
                  case: Dict,
                  properties: Dict,
                  parameters: Optional[Dict] = None,
                  seg_key: str = 'pred_seg',
                  data_key: str = 'data',
                  **kwargs,
                  ):
        """
        Primary way to instantiate this class. Automatically extracts all
        properties and uses a default set of parameters for ensembling.

        Args:
            case: case which is predicted.
            properties: Additional properties.
                Required keys:
                    `transpose_backward`
                    `original_spacing`
                    `spacing_after_resampling`
                    `crop_bbox`
                    `size_after_cropping`
                    `original_size_of_raw_data`
                    `itk_origin`, `itk_spacing`, `itk_direction`
            parameters: Additional parameters; overwrite the defaults
                (`use_gaussian=True`, `argmax=True`). Defaults to None.
            seg_key: key where segmentation is located inside prediction dict
            data_key: key where data is located inside batch dict
        """
        parameters = parameters if parameters is not None else {}
        _parameters = {"use_gaussian": True, "argmax": True}
        _parameters.update(parameters)
        _properties = {
            "shape": case[data_key].shape[1:],  # remove channel dim
            "transpose_backward": properties["transpose_backward"],
            "original_spacing": properties["original_spacing"],
            "spacing_after_resampling": properties["spacing_after_resampling"],
            "crop_bbox": properties["crop_bbox"],
            "size_after_cropping": properties["size_after_cropping"],
            "original_size_before_cropping": properties["original_size_of_raw_data"],
            "itk_origin": properties["itk_origin"],
            "itk_spacing": properties["itk_spacing"],
            "itk_direction": properties["itk_direction"],
        }
        return cls(
            properties=_properties,
            parameters=_parameters,
            seg_key=seg_key,
            data_key=data_key,
            **kwargs,
        )

    def add_model(self,
                  name: Optional[str] = None,
                  model_weight: Optional[float] = None,
                  ) -> str:
        """
        This functions signales the ensembler to add a new model for internal
        processing

        Args:
            name: Name of the model. If None, the current number of models
                plus one is used as the name.
                NOTE(review): the auto-generated name is an int, not a str
                as annotated — confirm whether callers rely on str names.
            model_weight: Optional weight for this model. Defaults to None
                (interpreted as 1.0).

        Returns:
            str: name under which the model was registered

        Raises:
            ValueError: if a model with `name` was already added
        """
        if name is None:
            name = len(self.model_weights) + 1
        if name in self.model_weights:
            raise ValueError(f"Invalid model name, model {name} is already present")
        if model_weight is None:
            model_weight = 1.0
        self.model_weights[name] = model_weight
        # subsequent process_batch calls are attributed to this model
        self.model_current = name
        return name

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any], Dict[str, Sequence[Any]]]:
        """
        Not available for segmentation ensembler

        Returns:
            Dict[str, Any]: always empty (no default parameters to sweep)
            Dict[str, Sequence[Any]]: always empty (no values to sweep)
        """
        return {}, {}

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of predictions: weight the predicted logits
        and accumulate them (and the weights) at their case positions.

        Args:
            result: prediction from detector.
                `self.seg_key`: [Tensor]: predicted segmentation [N, C, dims]
            batch: input batch
                `tile_origin`: origin of crop with regard to actual data
                    (in case of padding)
                `crop`: Sequence[slice] original crop from data
        """
        seg_batch = result[self.seg_key].cpu()
        crops = batch["crop"]
        weight = self.get_weighting(tuple(seg_batch.shape[2:])).to(seg_batch)
        # apply per-voxel crop weight and the weight of the current model
        seg_batch = seg_batch * weight[None].to(seg_batch) * self.model_weights[self.model_current]
        if self.model_results is None:
            # lazy allocation: channel count only known from first prediction
            self.model_results = torch.zeros(
                (int(seg_batch.shape[1]), *self.properties["shape"])).to(seg_batch)
        for seg, crop in zip(seg_batch, zip(*crops)):
            _weight = weight.clone()
            # remove padded regions so crops fit into the case volume
            seg, case_crop = self.crop_to_case_boundaries(seg, crop)
            _weight, case_crop2 = self.crop_to_case_boundaries(_weight[None], crop)
            assert case_crop == case_crop2
            self.model_results[case_crop] += seg
            self.overlap[case_crop] += _weight[0]

    def crop_to_case_boundaries(self, seg: torch.Tensor, crop: Sequence[slice]):
        """
        In case padding was used at the borders, the padding needs to be removed

        Args:
            seg: predicted segmentation of a single crop [C, dims]
            crop: Sequence[slice] location of the crop inside the case;
                may reach outside the case boundaries if padding was used

        Returns:
            torch.Tensor: segmentation restricted to the valid (in-case) region
            Tuple: index (Ellipsis followed by slices) where the returned
                segmentation has to be written inside the case volume
        """
        # drop leading (non-spatial) entries if crop has more dims than needed
        if len(crop) > self.model_results.ndim - 1:
            crop = crop[-(self.model_results.ndim - 1):]
        crop_slicer = []
        case_slicer = []
        for dim, c in enumerate(crop):
            # clip crop boundaries to the case volume
            case_start = max(0, c.start)
            case_stop = min(self.model_results.shape[dim + 1], c.stop)
            # amount the crop exceeds the upper case boundary (if positive)
            diff_stop = c.stop - self.model_results.shape[dim + 1]
            crop_start = max(0, 0 - (c.start - 0))  # 0 added for completeness of pattern
            crop_stop = min(seg.shape[dim + 1], seg.shape[dim + 1] - diff_stop)
            crop_slicer.append(slice(crop_start, crop_stop, c.step))
            case_slicer.append(slice(case_start, case_stop, c.step))
        return seg[(..., *crop_slicer)], (..., *case_slicer)

    def get_weighting(self, crop_size: Tuple[int]) -> torch.Tensor:
        """
        Get matrix to weight predictions inside a single patch. Results are
        cached per crop size.

        Args:
            crop_size: size of crop

        Returns:
            Tensor: weight for crop (gaussian centered in the crop if
                `use_gaussian` is enabled, otherwise all ones)
        """
        if crop_size not in self.cache_crop_weight:
            if self.parameters["use_gaussian"]:
                logger.info(f"Creating new gaussian weight matrix for crop size {crop_size}")
                # delta impulse at the crop center, smoothed into a gaussian
                tmp = np.zeros(crop_size)
                center_coords = [i // 2 for i in crop_size]
                sigmas = [i // 8 for i in crop_size]
                tmp[tuple(center_coords)] = 1
                tmp_smooth = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
                # normalize to max 1 and avoid exact zeros at the border
                tmp_smooth = tmp_smooth / tmp_smooth.max() * 1
                weighting = tmp_smooth + 1e-8
                self.cache_crop_weight[crop_size] = torch.from_numpy(weighting).float()
            else:
                logger.info(f"Creating new weight matrix for crop size {crop_size}")
                self.cache_crop_weight[crop_size] = torch.ones(crop_size, dtype=torch.float)
        return self.cache_crop_weight[crop_size]

    def restore_prediction(self, logit_maps: Tensor) -> Tensor:
        """
        Restore predictions in the original image space

        Args:
            logit_maps: consolidated segmentation logits [C, dims]

        Returns:
            Tensor: logit maps resampled/uncropped into the original image
                space [C, original_dims]
        """
        _old_dtype = logit_maps.dtype
        logit_maps_np = restore_fmap(
            fmap=logit_maps.detach().cpu().numpy(),
            transpose_backward=self.properties["transpose_backward"],
            original_spacing=self.properties["original_spacing"],
            spacing_after_resampling=self.properties["spacing_after_resampling"],
            original_size_before_cropping=self.properties["original_size_before_cropping"],
            size_after_cropping=self.properties["size_after_cropping"],
            crop_bbox=self.properties["crop_bbox"],
            interpolation_order=1,
            interpolation_order_z=0,
            do_separate_z=None,
        )
        logit_maps = torch.from_numpy(logit_maps_np).to(dtype=_old_dtype)
        return logit_maps

    @torch.no_grad()
    def get_case_result(self,
                        restore: bool = False, **kwargs
                        ) -> Dict[str, Tensor]:
        """
        Get final result for case after ensembling and TTA

        Args:
            restore: restore prediction in original image space
            kwargs: ignored

        Returns:
            Dict: results
                `pred_seg`: [dims] (uint8) if parameter `argmax` is True,
                    otherwise the raw per-class maps [C, dims]
                `restore`: indicate whether predictions were restored in
                    original image space
                `itk_origin`: itk origin of image before preprocessing
                `itk_spacing`: itk spacing of image before preprocessing
                `itk_direction`: itk direction of image before preprocessing
        """
        # normalize accumulated logits by the accumulated weights
        result = self.model_results / self.overlap[None]
        if restore:
            result = self.restore_prediction(result)
        if self.parameters["argmax"]:
            result = result.argmax(dim=0).to(dtype=torch.uint8)
        self.case_result = result
        return {
            "pred_seg": self.case_result,
            "restore": restore,
            "itk_origin": self.properties["itk_origin"],
            "itk_spacing": self.properties["itk_spacing"],
            "itk_direction": self.properties["itk_direction"],
        }

    def save_state(self,
                   target_dir: Path,
                   name: str,
                   **kwargs,
                   ):
        """
        Save case result as pickle file. Identifier of ensembler will
        be added to the name

        Args:
            target_dir: folder to save result to
            name: name of case
            kwargs: additional entries stored in the state
        """
        super().save_state(
            target_dir=target_dir,
            name=name,
            seg_key=self.seg_key,
            data_key=self.data_key,
            case_crop_weight=self.cache_crop_weight,
            **kwargs,
        )

    @classmethod
    def from_checkpoint(cls, base_dir: PathLike, case_id: str, **kwargs):
        """
        Restore an ensembler instance from a saved state.

        Args:
            base_dir: directory which contains the saved state
            case_id: case identifier used to build the file name
            kwargs: ignored
        """
        ckp = torch.load(str(Path(base_dir) / f"{case_id}_{cls.ID}.pt"))
        t = cls(
            properties=ckp["properties"],
            parameters=ckp["parameters"],
            seg_key=ckp["seg_key"],
            data_key=ckp["data_key"],
        )
        t._load(ckp)
        return t
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pathlib import Path
from typing import Sequence, List, Dict, Callable, Optional
import numpy as np
from loguru import logger
from nndet.utils.tensor import to_numpy
from nndet.io.load import load_pickle, save_pickle
from nndet.io.paths import Pathlike, get_case_id_from_path
from nndet.inference.loading import load_final_model
def predict_dir(
        source_dir: Pathlike,
        target_dir: Pathlike,
        cfg: dict,
        plan: dict,
        source_models: Path,
        model_fn: Callable[[Path, dict, dict, int], Sequence[dict]] = load_final_model,
        num_models: int = None,
        num_tta_transforms: int = None,
        restore: bool = False,
        case_ids: Optional[Sequence[str]] = None,
        save_state: bool = False,
        **kwargs
):
    """
    Predict all preprocessed(!) cases inside a directory

    Args:
        source_dir: directory where preprocessed cases are located
        target_dir: directory to save predictions to
        cfg: config (passed to `model_fn` to build the models)
        plan: plan
        source_models: directory where models for prediction are located
        model_fn: function to load model from directory
        num_models: number of models to use for prediction; None = all
        num_tta_transforms: number of tta transforms to use for
            prediction; None = all
        restore: restore predictions in original image space
        case_ids: case ids to predict. If None the whole folder will be
            predicted
        save_state: If `true` the state of the ensembler is saved. If
            `false` only the final result is saved.
        kwargs: passed to :method:`get_predictor` method of module

    Returns:
        the predictor used for inference
    """
    logger.info("Running inference")
    source_dir = Path(source_dir)
    target_dir = Path(target_dir)
    models = model_fn(source_models, cfg, plan, num_models)
    # the first loaded model builds the predictor; all models are ensembled
    predictor = models[0]["model"].get_predictor(
        plan=plan,
        models=[m["model"] for m in models],
        num_tta_transforms=num_tta_transforms,
        **kwargs,
    )
    if case_ids is None:
        # predict everything in the folder; skip ground-truth archives
        case_paths = list(source_dir.glob('*.npz'))
        case_paths = [cp for cp in case_paths if "_gt.npz" not in str(cp)]
    else:
        case_paths = [source_dir / f"{cid}.npz" for cid in case_ids]
    logger.info(f"Found {len(case_paths)} files for inference.")
    for idx, path in enumerate(case_paths, start=1):
        logger.info(f"Predicting case {idx} of {len(case_paths)}.")
        case_id = get_case_id_from_path(str(path), remove_modality=False)
        if path.is_file():
            case = np.load(str(path), allow_pickle=True)['data']
        else:
            # fall back to an unpacked .npy next to the expected .npz
            case = np.load(str(path)[:-4] + ".npy", allow_pickle=True)
        properties = load_pickle(path.parent / f"{case_id}.pkl")
        properties["transpose_backward"] = plan["transpose_backward"]
        if save_state:
            # the ensembler persists its state itself (e.g. for sweeps);
            # the returned result is not needed here
            _ = predictor.predict_case({"data": case},
                                       properties,
                                       save_dir=target_dir,
                                       case_id=case_id,
                                       restore=restore,
                                       )
        else:
            result = predictor.predict_case({"data": case},
                                            properties,
                                            save_dir=None,
                                            case_id=None,
                                            restore=restore,
                                            )
            # save the final (numpy) results, one pickle per result key
            for key, item in to_numpy(result).items():
                save_pickle(item, target_dir / f"{case_id}_{key}.pkl")
    return predictor
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from functools import partial
from pathlib import Path
from typing import Sequence, Optional
import torch
from loguru import logger
from nndet.ptmodule import MODULE_REGISTRY
from nndet.io.paths import Pathlike
def get_loader_fn(mode: str, **kwargs):
    """
    Resolve the model loading function for a given mode.

    Args:
        mode: "all" (case-insensitive) selects :func:`load_all_models`;
            any other value selects :func:`load_final_model` with the mode
            used as checkpoint identifier
        kwargs: bound to :func:`load_final_model` (unused for "all")

    Returns:
        Callable: function which loads the requested model(s)
    """
    if mode.lower() != "all":
        return partial(load_final_model, identifier=mode, **kwargs)
    return load_all_models
def get_latest_model(base_dir: Pathlike, fold: int = 0) -> Optional[Path]:
    """
    Get the latest training dir in a given base dir

    E.g. ../RetinaUNetV0/fold0__0, ../RetinaUNetV0/fold0__1
    -> would select fold0__1

    Args:
        base_dir: path to base dir
        fold: fold to look for

    Returns:
        Optional[Path]: latest training dir of the fold; None if no
            directory for the specified fold is found
    """
    candidates = [
        sub for sub in Path(base_dir).iterdir()
        if sub.is_dir() and f"fold{fold}" in sub.stem
    ]
    if not candidates:
        return None
    # "latest" == lexicographically largest stem (e.g. fold0__1 > fold0__0)
    return max(candidates, key=lambda p: p.stem)
def load_final_model(
    source_models: Path,
    cfg: dict,
    plan: dict,
    num_models: int = 1,
    identifier: str = "last",
) -> Sequence[dict]:
    """
    Load final model from training

    Args:
        source_models: path to directory where models are saved
        cfg: config used for experiment
            `module`: name of module in MODULE_REGISTRY
        plan: plan used for training
        num_models: Only supports one model
        identifier: looks for identifier inside of model name

    Returns:
        Sequence[dict]: loaded models
            `model`: loaded model
            `rank`: rank is always 0
    """
    assert num_models == 1, f"load_final_model only supports num_models=1, found {num_models}"
    logger.info(f"Loading {identifier} model")
    # exactly one checkpoint must match the identifier
    model_names = [ckpt for ckpt in source_models.glob('*.ckpt')
                   if identifier in str(ckpt.stem)]
    assert len(model_names) == 1, f"Found wrong number of models, {model_names} in {source_models} with {identifier}"
    path = model_names[0]
    # instantiate the module from the registry and restore its weights
    module = MODULE_REGISTRY[cfg["module"]](
        model_cfg=cfg["model_cfg"],
        trainer_cfg=cfg["trainer_cfg"],
        plan=plan,
    )
    state_dict = torch.load(path, map_location="cpu")["state_dict"]
    t = module.load_state_dict(state_dict)
    logger.info(f"Loaded {path} with {t}")
    module.float()
    module.eval()
    return [{"model": module, "rank": 0}]
def load_all_models(
    source_models: Path,
    cfg: dict,
    plan: dict,
    *args,
    **kwargs,
):
    """
    Load all models to ensemble

    Args:
        source_models: path to directory where models are saved
        cfg: config used for experiment
            `module`: name of module in MODULE_REGISTRY
        plan: plan used for training
        args: not used
        kwargs: not used

    Returns:
        Sequence[dict]: loaded models
            `model`: loaded model (moved to CPU, float, eval mode)
    """
    model_names = list(source_models.glob('*.ckpt'))
    if not model_names:
        raise RuntimeError(f"Did not find any models in {source_models}")
    logger.info(f"Found {len(model_names)} models to ensemble")

    def _load_single(path):
        # build the module from the registry and restore its weights
        module = MODULE_REGISTRY[cfg["module"]](
            model_cfg=cfg["model_cfg"],
            trainer_cfg=cfg["trainer_cfg"],
            plan=plan,
        )
        state_dict = torch.load(path, map_location="cpu")["state_dict"]
        t = module.load_state_dict(state_dict)
        logger.info(f"Loaded {path} with {t}")
        module.float()
        module.eval()
        return {"model": module.cpu()}

    return [_load_single(path) for path in model_names]
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import time
import copy
import collections
import numpy as np
import torch
from torch.utils.data import DataLoader
from loguru import logger
from typing import Hashable, List, Sequence, Dict, Union, Any, Optional, Callable, TypeVar
from pathlib import Path
from nndet.io.load import save_pickle
from nndet.arch.abstract import AbstractModel
from nndet.io.transforms import NoOp
from nndet.io.transforms.base import AbstractTransform
from nndet.io.patching import save_get_crop, create_grid
from nndet.utils import to_device, maybe_verbose_iterable
torch_device = Union[torch.device, str]
class Predictor:
    def __init__(self,
                 ensembler: Dict[str, Callable],
                 models: Sequence[AbstractModel],
                 crop_size: Sequence[int],
                 overlap: float = 0.5,
                 tile_keys: Sequence[str] = ('data',),
                 model_keys: Sequence[str] = ('data',),
                 tta_transforms: Sequence[AbstractTransform] = (NoOp(),),
                 tta_inverse_transforms: Sequence[AbstractTransform] = (NoOp(),),
                 pre_transform: AbstractTransform = None,
                 post_transform: AbstractTransform = None,
                 batch_size: int = 4,
                 model_weights: Sequence[float] = None,
                 device: torch_device = "cuda:0",
                 ensemble_on_device: bool = True,
                 ):
        """
        Predict entire cases with TTA and Model-Ensembling

        Workflow
            - Load whole patient
            -> create predictor from patient
            - tile patient
            * for each model:
                * for each batch (batches of tiles):
                    * for each tta transform:
                        - pre transform
                        - tta transform
                        - post transform
                        - predict batch
                        - inverse tta transform
                        - forward predictions and batch to ensembler classes
            <- return patient result

        Args:
            ensembler: Callable to instantiate ensembler from case and
                properties
            models: models to ensemble
            crop_size: size of each crop (for most cases this should be
                the same as in training)
            overlap: overlap of crops
            tile_keys: keys which are tiles
            model_keys: these keys are passed as positional arguments to the
                model
            tta_transforms: tta transformations
            tta_inverse_transforms: inverse tta transformation
            pre_transform: transform which is performed before every tta
                transform
            post_transform: transform which is performed after every tta
                transform
            batch_size: batch size to use for prediction
            model_weights: additional weighting of individual models
            device: device used for prediction
            ensemble_on_device: The results will be passed to the ensembler
                class with the current device. The ensembler needs to make
                sure to avoid memory leaks!
        """
        self.ensemble_on_device = ensemble_on_device
        self.device = device
        # factories which create the per-case ensembler instances
        self.ensembler_fns = ensembler
        self.ensembler = {}
        self.models = models
        # default: all models are weighted equally
        self.model_weights = [1.] * len(models) if model_weights is None else model_weights
        self.crop_size = crop_size
        self.overlap = overlap
        self.tile_keys = tile_keys
        self.model_keys = model_keys
        self.batch_size = batch_size
        if len(tta_transforms) != len(tta_inverse_transforms):
            raise ValueError("Every tta transform needs a reverse transform")
        self.tta_transforms = tta_transforms
        self.tta_inverse_transforms = tta_inverse_transforms
        self.post_transform = post_transform
        self.pre_transform = pre_transform
        # modes used for grid creation / crop extraction (see tile_case)
        self.grid_mode = 'symmetric'
        self.save_get_mode = 'shift'

    @classmethod
    def create(cls, *args, **kwargs):
        """
        Create predictor object with specific ensembler objects

        Raises:
            NotImplementedError: Need to be overwritten in subclasses
        """
        raise NotImplementedError

    @classmethod
    def get_ensembler(cls, key: Hashable, dim: int) -> Callable:
        """
        Return ensembler class for specific keys
        Typically: `boxes`, `seg`, `instances`

        Args:
            key: Key to return
            dim: number of spatial dimensions the network expects

        Raises:
            NotImplementedError: Need to be overwritten in subclasses

        Returns:
            Callable: Ensembler class
        """
        raise NotImplementedError

    def predict_case(self,
                     case: Dict,
                     properties: Optional[Dict] = None,
                     save_dir: Optional[Union[Path, str]] = None,
                     case_id: Optional[str] = None,
                     restore: bool = False,
                     ) -> dict:
        """
        Load and predict a single case.

        Args:
            case: data of a single case
            properties: additional properties of the case. E.g. to
                restore prediction in original image space
            save_dir: directory to save predictions
            case_id: used for saving
            restore: restore prediction in original image space
                ("revert" preprocessing)

        Returns:
            dict: result of each ensembler
        """
        tic = time.perf_counter()
        # fresh ensembler instances per case
        for name, fn in self.ensembler_fns.items():
            self.ensembler[name] = fn(case, properties=properties)
        tiles = self.tile_case(case)
        self.predict_tiles(tiles)
        result = {key: value.get_case_result(restore=restore) for key, value in self.ensembler.items()}
        if save_dir is not None:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
            # persist ensembler states (e.g. for later parameter sweeps)
            for ensembler in self.ensembler.values():
                ensembler.save_state(save_dir, name=case_id)
            save_pickle(properties, save_dir / f"{case_id}_properties.pkl")
        toc = time.perf_counter()
        logger.info(f"Prediction took {toc - tic} s")
        return result

    def tile_case(self, case: dict, update_remaining: bool = True) -> \
            Sequence[Dict[str, np.ndarray]]:
        """
        Create patches from whole patient for prediction

        Args:
            case: data of a single case
            update_remaining: properties from case which are not tiles
                are saved into all patches

        Returns:
            Sequence[Dict[str, np.ndarray]]: extracted crops from case
                and added new key:
                `tile_origin`: Sequence[int] offset of tile relative
                    to case origin
                `crop`: Sequence[slice] crop used to extract the tile
        """
        dshape = case[self.tile_keys[0]].shape
        overlap = [int(c * self.overlap) for c in self.crop_size]
        crops = create_grid(
            cshape=self.crop_size,
            dshape=dshape[1:],
            overlap=overlap,
            mode=self.grid_mode,
        )
        tiles = []
        for crop in crops:
            try:
                # try selected extraction mode
                tile = {key: save_get_crop(case[key], crop, mode=self.save_get_mode)[0]
                        for key in self.tile_keys}
                _, tile["tile_origin"], tile["crop"] = save_get_crop(
                    case[self.tile_keys[0]], crop, mode=self.save_get_mode)
            except RuntimeError:
                # fallback to symmetric
                logger.warning("Path size is bigger than whole case, padding case to match patch size")
                tile = {key: save_get_crop(case[key], crop, mode="symmetric")[0]
                        for key in self.tile_keys}
                _, tile["tile_origin"], tile["crop"] = save_get_crop(
                    case[self.tile_keys[0]], crop, mode="symmetric")
            if update_remaining:
                tile.update({key: item for key, item in case.items()
                             if key not in self.tile_keys})
            tiles.append(tile)
        return tiles

    @torch.no_grad()
    def predict_tiles(self, tiles: Sequence[Dict]) -> None:
        """
        Predict tiles of a single case with ensembling and tta. Results
        are saved inside ensemblers

        Args:
            tiles: tiles from single case
        """
        dataloader = DataLoader(tiles,
                                batch_size=self.batch_size,
                                shuffle=False,
                                collate_fn=slice_collate,
                                )
        for model_idx, (model, model_weight) in enumerate(
                zip(self.models, self.model_weights)):
            logger.info(f"Predicting model {model_idx + 1} of "
                        f"{len(self.models)} with weight {model_weight}.")
            model.to(device=self.device)
            model.eval()
            for t, (transform, inverse_transform) in enumerate(maybe_verbose_iterable(
                    list(zip(self.tta_transforms, self.tta_inverse_transforms)),
                    desc="Transform", position=0)):
                # every model/transform combination is registered as its own
                # (weighted) "model" inside the ensemblers
                for ensembler in self.ensembler.values():
                    ensembler.add_model(name=f"model{model_idx}_t{t}", model_weight=model_weight)
                for batch_num, batch in enumerate(maybe_verbose_iterable(
                        dataloader, desc="Crop", position=1)):
                    self.predict_with_transformation(
                        model=model,
                        batch=batch,
                        batch_num=batch_num,
                        transform=transform,
                        inverse_transform=inverse_transform,
                    )
            # free GPU memory before moving on to the next model
            model.cpu()
            torch.cuda.empty_cache()

    def predict_with_transformation(self,
                                    model: AbstractModel,
                                    batch: Dict,
                                    batch_num: int,
                                    transform: Callable,
                                    inverse_transform: Callable,
                                    ):
        """
        Run prediction with the specified transformations

        Args:
            model: model to predict
            batch: input batch to model
            batch_num: batch index
            transform: transform to apply to batch.
            inverse_transform: inverse transform to apply to results
        """
        batch = to_device(batch, device=self.device)
        if self.pre_transform is not None:
            batch = self.pre_transform(**batch)
        transformed = transform(**batch)
        if self.post_transform is not None:
            transformed = self.post_transform(**transformed)
        inp = [transformed[key] for key in self.model_keys]
        # mixed precision inference
        with torch.cuda.amp.autocast():
            result = model.inference_step(*inp, batch_num=batch_num)
        result = inverse_transform(**result)
        if not self.ensemble_on_device:
            result = to_device(result, device="cpu")
        for ensembler in self.ensembler.values():
            ensembler.process_batch(result=result, batch=batch)
def slice_collate(batch: List[Any]):
    """
    Collate function which passes lists of slices through unchanged and
    defers everything else to the default pytorch collate.

    Args:
        batch: batch to collate

    Returns:
        Any: collated items
    """
    sample = batch[0]
    if isinstance(sample, slice):
        # slices cannot be stacked -> keep them as a plain list
        return batch
    if isinstance(sample, collections.abc.Mapping):
        return {key: slice_collate([entry[key] for entry in batch]) for key in sample}
    if isinstance(sample, tuple) and hasattr(sample, '_fields'):  # namedtuple
        return type(sample)(*(slice_collate(group) for group in zip(*batch)))
    if isinstance(sample, collections.abc.Sequence):
        return [slice_collate(group) for group in zip(*batch)]
    # everything else (tensors, numbers, ...) uses the default collate
    return torch.utils.data._utils.collate.default_collate(batch)
PredictorType = TypeVar('PredictorType', bound=Predictor)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence, Tuple, Union, Optional
import numpy as np
from loguru import logger
from nndet.core.boxes.ops import permute_boxes, expand_to_boxes
from nndet.preprocessing.resampling import resample_data_or_seg, get_do_separate_z, get_lowres_axis
def restore_detection(boxes: np.ndarray,
                      transpose_backward: Sequence[int],
                      original_spacing: Sequence[float],
                      spacing_after_resampling: Sequence[float],
                      crop_bbox: Sequence[Tuple[int, int]],
                      **kwargs,
                      ) -> np.ndarray:
    """
    Restore boxes from preprocessed space into original space

    Args:
        boxes: predicted boxes in preprocessing space,
            (x1, y1, x2, y2, (z1, z2))[N, dims * 2]
        transpose_backward: backward transposing
        original_spacing: spacing of the original image
        spacing_after_resampling: spacing in the preprocessed space
            (forward transposed order)
        crop_bbox: bounding box crop in the original spacing [(min_0, max_0), ...]
        **kwargs: ignored

    Returns:
        np.ndarray: predicted bounding boxes in the original image space
    """
    # bring box coordinates back into the original axis order
    boxes_transposed = permute_boxes(boxes, transpose_backward)
    # ratio of (back-transposed) resampled spacing to original spacing,
    # broadcast from per-axis values to the box layout
    scaling = (np.asarray(spacing_after_resampling)[transpose_backward]
               / np.asarray(original_spacing))
    boxes_scaled = boxes_transposed * expand_to_boxes(scaling[None])
    # undo cropping: shift boxes by the lower crop bound of each axis
    offset = np.asarray([b[0] for b in crop_bbox])
    return boxes_scaled + expand_to_boxes(offset[None])
def restore_fmap(fmap: np.ndarray,
                 transpose_backward: Sequence[int],
                 original_spacing: Sequence[float],
                 spacing_after_resampling: Sequence[float],
                 original_size_before_cropping: Sequence[int],
                 size_after_cropping: Sequence[int],
                 crop_bbox: Optional[Sequence[Tuple[int, int]]] = None,
                 interpolation_order: int = 3,
                 interpolation_order_z: int = 0,
                 do_separate_z: bool = None,
                 ) -> np.ndarray:
    """
    Restore feature map from preprocessed space into original space

    Args:
        fmap: feature map to resample [C, dims], where C is the number of
            channels
        transpose_backward: backward transposing
        original_spacing: spacing of the original image
        spacing_after_resampling: spacing in the preprocessed space
            (forward transposed order)
        original_size_before_cropping: original image size before cropping
        size_after_cropping: image size after cropping
        crop_bbox: bounding box of crop
        interpolation_order: interpolation order for inplane axis
        interpolation_order_z: interpolation order for anisotropic axis
        do_separate_z: if None then we dynamically decide how to resample
            along z, if True/False then always/never resample along z
            separately. Do not touch unless you know what you are doing

    Returns:
        np.ndarray: resampled feature map [C, new_dims]
    """
    # channel dim stays in front; spatial dims are transposed back into the
    # original axis order
    fmap_transposed = np.transpose(fmap, [0] + [i + 1 for i in transpose_backward])
    original_spacing = np.asarray(original_spacing)
    spacing_after_resampling = np.asarray(spacing_after_resampling)
    resampled_spacing = spacing_after_resampling[transpose_backward]
    # Bugfix: compare and resample the *transposed* map so both branches work
    # in the same (original) axis order; previously the untransposed `fmap`
    # was used here while the else-branch returned `fmap_transposed`.
    if np.any([i != j for i, j in zip(fmap_transposed.shape[1:], size_after_cropping)]):
        lowres_axis = _get_lowres_axes(original_spacing, resampled_spacing,
                                       do_separate_z=do_separate_z)
        logger.info(f"Resampling: do separate z: {do_separate_z}; lowres axis: {lowres_axis}")
        fmap_old_spacing = resample_data_or_seg(fmap_transposed, size_after_cropping, is_seg=False,
                                                axis=lowres_axis, order=interpolation_order,
                                                do_separate_z=do_separate_z, cval=0,
                                                order_z=interpolation_order_z)
    else:
        logger.info(f"Resampling: no resampling necessary")
        fmap_old_spacing = fmap_transposed
    if crop_bbox is not None:
        # embed the map into a zero canvas of the original (uncropped) size
        crop_bbox = [list(cb) for cb in crop_bbox]
        tmp = np.zeros((fmap_old_spacing.shape[0], *original_size_before_cropping))
        for c in range(len(crop_bbox)):
            # clip the upper bound so the crop never exceeds the canvas
            crop_bbox[c][1] = np.min(
                (crop_bbox[c][0] + fmap_old_spacing.shape[c + 1], original_size_before_cropping[c]))
        # index with a tuple: indexing ndarrays with a list of slices is
        # deprecated in numpy
        _slices = (..., *[slice(b[0], b[1]) for b in crop_bbox])
        tmp[_slices] = fmap_old_spacing
        fmap_original = tmp
    else:
        fmap_original = fmap_old_spacing
    return fmap_original
def _get_lowres_axes(original_spacing: Sequence[float],
                     resampled_spacing: Sequence[float],
                     do_separate_z: bool) -> Union[Sequence[int], None]:
    """
    Dynamically determine lowres axes

    Args:
        original_spacing: original spacing (not transposed!)
        resampled_spacing: resampled spacing (not transposed!)
        do_separate_z: force separate resampling along z (True/False);
            None enables automatic detection based on anisotropy

    Returns:
        Union[Sequence[int], None]: Lowres axes. If None, no lowres axes
            is present
    """
    if do_separate_z is None:
        # auto mode: prefer the original spacing if it is anisotropic,
        # otherwise check the resampled spacing
        if get_do_separate_z(original_spacing):
            return get_lowres_axis(original_spacing)
        if get_do_separate_z(resampled_spacing):
            return get_lowres_axis(resampled_spacing)
        return None
    return get_lowres_axis(original_spacing) if do_separate_z else None
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABC, abstractmethod
from pathlib import Path
import time
from typing import Callable, Tuple, Dict, Sequence, Any, Optional, TypeVar
import numpy as np
from loguru import logger
from nndet.io.paths import Pathlike
from nndet.io.load import save_json
from nndet.utils.info import maybe_verbose_iterable
from nndet.utils import to_numpy
from nndet.evaluator.registry import BoxEvaluator
class Sweeper(ABC):
    def __init__(self,
                 classes: Sequence[str],
                 pred_dir: Pathlike,
                 gt_dir: Pathlike,
                 target_metric: str,
                 save_dir: Optional[Pathlike] = None,
                 ):
        """
        Base class to sweep multiple parameters and compute evaluation
        metrics to determine the best set of parameters

        Args:
            classes: classes present in the dataset
            pred_dir: directory where predicted data is saved
            gt_dir: directory where ground truth is saved
            target_metric: metric to optimize
            save_dir: directory to save sweep results; created if it does
                not exist. If None, results are not written to disk.
        """
        self.classes = classes
        self.target_metric = target_metric
        # sweeps are always evaluated on the CPU
        self.device = "cpu"
        self.pred_dir = Path(pred_dir)
        self.gt_dir = Path(gt_dir)
        if save_dir is None:
            self.save_dir = None
        else:
            self.save_dir = Path(save_dir)
            self.save_dir.mkdir(parents=True, exist_ok=True)

    @abstractmethod
    def run_postprocessing_sweep(self,
                                 restore: bool = True,
                                 ) -> Tuple[Dict, Dict]:
        """
        Run parameter sweeps to determine the best parameters
        according to the target metric

        Args:
            restore: whether predictions should be restored to their
                original space

        Returns:
            Dict: determined parameters
            Dict: final results with parameters
        """
        raise NotImplementedError
class BoxSweeper(Sweeper):
    def __init__(self,
                 classes: Sequence[str],
                 pred_dir: Pathlike,
                 gt_dir: Pathlike,
                 target_metric: str,
                 ensembler_cls: Callable,
                 save_dir: Optional[Pathlike] = None,
                 ) -> None:
        """
        Run sweep over ensembler parameters and select the best values

        Args:
            classes: classes present in dataset
            pred_dir: directory where predictions are saved
            gt_dir: directory where ground truth is saved
            target_metric: metric to optimize
            ensembler_cls: ensembler class used during prediction
            save_dir: Directory to save results. Defaults to None.
        """
        super().__init__(classes=classes,
                         pred_dir=pred_dir,
                         gt_dir=gt_dir,
                         target_metric=target_metric,
                         save_dir=save_dir,
                         )
        self.evaluator_cls = BoxEvaluator
        self.ensembler_cls = ensembler_cls

    def run_postprocessing_sweep(self):
        """
        Sequentially (greedily) search for the best parameters: each
        parameter is swept while all previously swept parameters are
        fixed to their best value.

        Returns:
            Dict: final parameters to run inference on new cases
        """
        state, sweep_params = self.ensembler_cls.sweep_parameters()
        # bugfix: previously the raw list of case ids was interpolated into
        # the log message instead of the number of cases
        case_ids = self.ensembler_cls.get_case_ids(self.pred_dir)
        logger.info(f"Running parameter sweep on {len(case_ids)} cases to optimize "
                    f"{self.target_metric} with initial state {state}.")

        best_score = float('-inf')
        for param_name, values in sweep_params.items():
            best_value, _best_score = self.run_parameter(
                values=values,
                param_name=param_name,
                state=state,
            )
            state[param_name] = best_value
            # greedy sweeping fixes each parameter at its best value, so the
            # best score must never decrease between consecutive parameters
            if _best_score < best_score:
                logger.error("ERROR: Something went wrong during sweeping. "
                             "Results were modified inplace! "
                             f"Previous: {best_score} now {_best_score}")
            best_score = _best_score
        logger.info(f"\n\n Determined {state} with best sweeping score {best_score} {self.target_metric}\n\n")
        return state

    def run_parameter(self,
                      values: Sequence[Any],
                      param_name: str,
                      state: Dict[str, Any],
                      ):
        """
        Evaluate all candidate values of one parameter and select the best

        Args:
            values: values to evaluate
            param_name: name of parameter
            state: current (partially optimized) ensembler state

        Returns:
            Any: best value found for this parameter
            float: score of the best value w.r.t. the target metric
        """
        cache = []
        overview = {}
        for value in values:
            logger.info(f"Running sweep {param_name}={value}")
            tic = time.perf_counter()
            metric_scores = self._evaluate_value(state=state, **{param_name: value})
            overview[f"{param_name}_{value}".replace(".", "_")] = {
                "state": str(state),
                "overwrite": {param_name: str(value)},
                "scores": str(metric_scores),
            }
            cache.append(metric_scores[self.target_metric])
            toc = time.perf_counter()
            logger.info(f"Sweep took {toc - tic} s")

        best_idx = np.argmax(cache)
        best_value = values[best_idx]
        best_score = cache[best_idx]
        if self.save_dir is not None:
            overview[f"best_{param_name}"] = {"value": str(best_value), "score": str(best_score)}
            save_json(overview, self.save_dir / f"sweep_{param_name}.json")
        return best_value, best_score

    def _evaluate_value(self,
                        state: Dict[str, Any],
                        **overwrite,
                        ):
        """
        Evaluate a single parameter configuration over all cases

        Args:
            state: state for ensembler
            overwrite: state overwrites

        Returns:
            Dict: scalar metrics
        """
        evaluator = self.evaluator_cls.create(classes=self.classes,
                                              fast=True,
                                              verbose=False,
                                              save_dir=None,
                                              )
        for case_id in maybe_verbose_iterable(self.ensembler_cls.get_case_ids(self.pred_dir)):
            # restore ensembler state from disk and apply the swept parameters
            ensembler = self.ensembler_cls.from_checkpoint(
                base_dir=self.pred_dir, case_id=case_id, device=self.device,
            )
            ensembler.update_parameters(**state)
            ensembler.update_parameters(**overwrite)

            pred = to_numpy(ensembler.get_case_result(restore=False))
            gt = np.load(str(self.gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)

            evaluator.run_online_evaluation(
                pred_boxes=[pred["pred_boxes"]], pred_classes=[pred["pred_labels"]],
                pred_scores=[pred["pred_scores"]], gt_boxes=[gt["boxes"]],
                gt_classes=[gt["classes"]], gt_ignore=None,
            )
        metric_scores, _ = evaluator.finish_online_evaluation()
        return metric_scores
# Type variable bound to Sweeper so factories/annotations can preserve the
# concrete Sweeper subclass type.
SweeperType = TypeVar('SweeperType', bound=Sweeper)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from loguru import logger
from typing import List, Tuple, Sequence
from nndet.io.transforms import Mirror, NoOp
from nndet.io.transforms.base import AbstractTransform
def get_tta_transforms(num_tta_transforms: int, seg: bool = True) -> Tuple[
        List[AbstractTransform], List[AbstractTransform]]:
    """
    Get tta transformations

    Args:
        num_tta_transforms: number of tta transformations; 0: no tta, 4: augments
            all directions in 2D, 8: augments all directions in 3D
        seg: whether a segmentation prediction ('pred_seg') is present; if
            False, the inverse transforms only need to handle the predicted
            boxes

    Returns:
        List[AbstractTransform]: transforms for TTA
        List[AbstractTransform]: inverted transformations for TTA
    """
    transforms = [NoOp()]
    inverse_transforms = [NoOp()]

    mirror_keys = ["data"]
    # bugfix: both branches previously produced ["pred_seg"], so the inverse
    # transforms tried to mirror a segmentation even when none was predicted
    pred_mirror_keys = ["pred_seg"] if seg else []
    boxes_mirror_keys = ["pred_boxes"]

    if num_tta_transforms >= 4:
        logger.info("Adding 2D Mirror TTA for prediction.")
        for dims in [(0,), (1,), (0, 1)]:
            transforms.append(Mirror(keys=mirror_keys, dims=dims))
            inverse_transforms.append(Mirror(keys=pred_mirror_keys,
                                             box_keys=boxes_mirror_keys, dims=dims))

    if num_tta_transforms >= 8:
        logger.info("Adding 3D Mirror TTA for prediction.")
        for dims in [(2,), (0, 2), (1, 2), (0, 1, 2)]:
            transforms.append(Mirror(keys=mirror_keys, dims=dims))
            inverse_transforms.append(Mirror(keys=pred_mirror_keys,
                                             box_keys=boxes_mirror_keys, dims=dims))
    return transforms, inverse_transforms
class Inference2D(AbstractTransform):
    def __init__(self,
                 keys: Sequence[str],
                 ):
        """
        Helper transform to run inference with 2d models

        Args:
            keys: data keys from which the leading spatial dimension is
                removed before inference
        """
        super().__init__(grad=False)
        self.keys = keys

    def forward(self, **data) -> dict:
        """
        Remove the first spatial dimension (N, C, [removed], ax1, ax2)
        from every registered key.

        Args:
            **data: input batch

        Returns:
            dict: transformed batch
        """
        data.update({key: data[key][:, :, 0] for key in self.keys})
        return data
from nndet.io.load import (
load_json,
load_pickle,
save_json,
save_pickle,
npy_dataset,
save_yaml,
)
from nndet.io.paths import (
get_case_id_from_file,
get_case_id_from_path,
get_case_ids_from_dir,
get_paths_from_splitted_dir,
get_paths_raw_to_split,
get_task, get_training_dir,
)
from nndet.io.itk import (
load_sitk,
load_sitk_as_array,
)
from typing import Mapping, Type
from nndet.io.augmentation.base import AugmentationSetup
from nndet.utils.registry import Registry
# Global registry mapping augmentation setup names to their classes;
# populated via the ``@AUGMENTATION_REGISTRY.register`` decorator.
AUGMENTATION_REGISTRY: Mapping[str, Type[AugmentationSetup]] = Registry()
from nndet.io.augmentation.bg_aug import (
NoAug,
DefaultAug,
BaseMoreAug,
MoreAug,
InsaneAug,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence, List
from abc import ABC, abstractmethod
import numpy as np
def get_patch_size(
    patch_size: Sequence[int],
    rot_x: float,
    rot_y: float,
    rot_z: float,
    scale_range: Sequence[float],
) -> np.ndarray:
    """
    Compute an enlarged patch size for augmentation so that rotation and
    scaling artifacts at the borders can be cropped away afterwards

    Args:
        patch_size: target spatial size after final cropping
        rot_x: rotation around x in radian (scalar or range)
        rot_y: rotation around y in radian (scalar or range)
        rot_z: rotation around z in radian (scalar or range)
        scale_range: scaling range

    Returns:
        np.ndarray: enlarged patch size for augmentation
    """
    # reduce rotation ranges to their maximum magnitude and clip to 90 degrees
    clipped_rotations = []
    for rot in (rot_x, rot_y, rot_z):
        if isinstance(rot, (tuple, list)):
            rot = max(np.abs(rot))
        clipped_rotations.append(min(90 / 360 * 2. * np.pi, rot))
    rot_x, rot_y, rot_z = clipped_rotations

    from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d
    coords = np.array(patch_size)
    final_shape = np.copy(coords)
    if len(coords) == 3:
        # take the elementwise maximum over each single-axis rotation
        for angles in ((rot_x, 0, 0), (0, rot_y, 0), (0, 0, rot_z)):
            rotated = np.abs(rotate_coords_3d(coords, *angles))
            final_shape = np.max(np.vstack((rotated, final_shape)), 0)
    elif len(coords) == 2:
        rotated = np.abs(rotate_coords_2d(coords, rot_x))
        final_shape = np.max(np.vstack((rotated, final_shape)), 0)
    # account for the maximum zoom-out of the scaling augmentation
    final_shape /= min(scale_range)
    return final_shape.astype(np.int32)
class AugmentationSetup(ABC):
    def __init__(self,
                 patch_size: Sequence[int],
                 params: dict,
                 ) -> None:
        """
        Base class bundling the augmentation configuration

        Args:
            patch_size: output patch size of the augmentations
            params: augmentation parameters

        Notes:
            Which keys of :attr:`params` are required depends on the
            concrete transformations used by the subclass.
        """
        self.patch_size = patch_size
        self.params = params

    @abstractmethod
    def get_training_transforms(self):
        """
        Set up training transformations; must be overwritten in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get_validation_transforms(self):
        """
        Set up validation transformations; must be overwritten in subclasses.
        """
        raise NotImplementedError

    def get_patch_size_generator(self) -> List[int]:
        """
        Compute the patch size to extract from the volume so that
        augmentation artifacts can be cropped away afterwards.
        """
        rotation_kwargs = {
            f"rot_{axis}": self.params[f"rotation_{axis}"] for axis in "xyz"
        }
        return list(get_patch_size(
            patch_size=self.patch_size,
            scale_range=self.params['scale_range'],
            **rotation_kwargs,
        ))
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence, List
from loguru import logger
from nndet.io.augmentation.base import AugmentationSetup, get_patch_size
from batchgenerators.transforms import (
DataChannelSelectionTransform,
SegChannelSelectionTransform,
SpatialTransform,
GammaTransform,
MirrorTransform,
Compose,
BrightnessMultiplicativeTransform,
ContrastAugmentationTransform,
GaussianNoiseTransform,
GaussianBlurTransform,
SimulateLowResolutionTransform,
RenameTransform,
NumpyToTensor,
CenterCropTransform,
)
from batchgenerators.transforms.color_transforms import BrightnessTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform
from nnunet.training.data_augmentation.custom_transforms import (
Convert3DTo2DTransform,
Convert2DTo3DTransform,
MaskTransform,
)
from nndet.io.augmentation import AUGMENTATION_REGISTRY
@AUGMENTATION_REGISTRY.register
class NoAug(AugmentationSetup):
    def __init__(self, patch_size: Sequence[int], params: dict) -> None:
        """
        Augmentation setup without intensity or spatial augmentation:
        only channel selection, center cropping, label cleanup and
        tensor conversion.
        """
        super().__init__(patch_size, params)
        # In dummy-2D mode 3d volumes are augmented slice wise, so the
        # spatial transform only sees the in-plane dimensions.
        self.dummy_2d = self.params.get("dummy_2D", False)
        if self.dummy_2d:
            logger.info("Running dummy 2d augmentation transforms!")
            self._spatial_transform_patch_size = self.patch_size[1:]
        else:
            self._spatial_transform_patch_size = self.patch_size

    def get_patch_size_generator(self) -> List[int]:
        """
        Compute patch size to extract from volume to avoid augmentation
        artifacts
        """
        generator_patch_size = list(get_patch_size(
            patch_size=self._spatial_transform_patch_size,
            rot_x=self.params['rotation_x'],
            rot_y=self.params['rotation_y'],
            rot_z=self.params['rotation_z'],
            scale_range=self.params['scale_range'],
        ))
        if self.dummy_2d:
            # the first (slice) dimension is not spatially augmented
            generator_patch_size = [self.patch_size[0]] + generator_patch_size
        return generator_patch_size

    def _base_transforms(self) -> list:
        """Channel selection, center crop, label cleanup, tensor conversion."""
        transforms = []
        selected_data_channels = self.params.get("selected_data_channels")
        if selected_data_channels:
            transforms.append(DataChannelSelectionTransform(selected_data_channels))
        selected_seg_channels = self.params.get("selected_seg_channels")
        if selected_seg_channels:
            transforms.append(SegChannelSelectionTransform(selected_seg_channels))
        transforms.append(CenterCropTransform(self.patch_size))
        transforms.append(RemoveLabelTransform(-1, 0))
        transforms.append(RenameTransform('seg', 'target', True))
        transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return transforms

    def get_training_transforms(self):
        """Training pipeline; identical to validation since nothing is augmented."""
        return Compose(self._base_transforms())

    def get_validation_transforms(self):
        """Validation pipeline without any augmentation."""
        return Compose(self._base_transforms())
@AUGMENTATION_REGISTRY.register
class DefaultAug(NoAug):
    # Default training augmentation: spatial transform (elastic deformation,
    # rotation, scaling) plus optional gamma and mirroring.
    # Validation transforms are inherited from NoAug.
    def get_training_transforms(self):
        """
        Build the training augmentation pipeline.

        Returns:
            Compose: composed batchgenerators transforms; the pipeline
            renames 'seg' to 'target' and converts 'data'/'target' to
            float tensors at the end.
        """
        assert self.params.get('mirror') is None, "old version of params, use new keyword do_mirror"
        tr_transforms = []
        # optional selection of a subset of data / seg channels
        if self.params.get("selected_data_channels"):
            tr_transforms.append(DataChannelSelectionTransform(
                self.params.get("selected_data_channels")))
        if self.params.get("selected_seg_channels"):
            tr_transforms.append(SegChannelSelectionTransform(
                self.params.get("selected_seg_channels")))
        if self.params.get("dummy_2D", False):
            # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
            tr_transforms.append(Convert3DTo2DTransform())
        tr_transforms.append(SpatialTransform(
            self._spatial_transform_patch_size,
            patch_center_dist_from_border=None,
            do_elastic_deform=self.params.get("do_elastic"),
            alpha=self.params.get("elastic_deform_alpha"),
            sigma=self.params.get("elastic_deform_sigma"),
            do_rotation=self.params.get("do_rotation"),
            angle_x=self.params.get("rotation_x"),
            angle_y=self.params.get("rotation_y"),
            angle_z=self.params.get("rotation_z"),
            do_scale=self.params.get("do_scaling"),
            scale=self.params.get("scale_range"),
            order_data=self.params.get("order_data"),
            border_mode_data=self.params.get("border_mode_data"),
            border_cval_data=self.params.get("border_cval_data"),
            order_seg=self.params.get("order_seg"),
            border_mode_seg=self.params.get("border_mode_seg"),
            border_cval_seg=self.params.get("border_cval_seg"),
            random_crop=self.params.get("random_crop"),
            p_el_per_sample=self.params.get("p_eldef"),
            p_scale_per_sample=self.params.get("p_scale"),
            p_rot_per_sample=self.params.get("p_rot"),
            independent_scale_for_each_axis=self.params.get("independent_scale_factor_for_each_axis"),
        ))
        if self.params.get("dummy_2D", False):
            # restore the original 3d layout after slice wise augmentation
            tr_transforms.append(Convert2DTo3DTransform())
        if self.params.get("do_gamma", False):
            tr_transforms.append(
                GammaTransform(self.params.get("gamma_range"), False, True,
                               retain_stats=self.params.get("gamma_retain_stats"),
                               p_per_sample=self.params["p_gamma"])
            )
        if self.params.get("do_mirror", False):
            tr_transforms.append(MirrorTransform(self.params.get("mirror_axes")))
        if self.params.get("use_mask_for_norm"):
            # zero out everything outside the normalization mask (seg channel 0)
            use_mask_for_norm = self.params.get("use_mask_for_norm")
            tr_transforms.append(MaskTransform(use_mask_for_norm, mask_idx_in_seg=0, set_outside_to=0))
        tr_transforms.append(RemoveLabelTransform(-1, 0))
        tr_transforms.append(RenameTransform('seg', 'target', True))
        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return Compose(tr_transforms)
@AUGMENTATION_REGISTRY.register
class BaseMoreAug(NoAug):
    # Heavier augmentation than DefaultAug: additionally applies noise, blur,
    # brightness, contrast and an inverted gamma transform.
    # Validation transforms are inherited from NoAug.
    def get_training_transforms(self):
        """
        Build the training augmentation pipeline.

        Returns:
            Compose: composed batchgenerators transforms; the pipeline
            renames 'seg' to 'target' and converts 'data'/'target' to
            float tensors at the end.
        """
        assert self.params.get('mirror') is None, "old version of params, use new keyword do_mirror"
        tr_transforms = []
        if self.params.get("selected_data_channels"):
            tr_transforms.append(DataChannelSelectionTransform(
                self.params.get("selected_data_channels")))
        if self.params.get("selected_seg_channels"):
            tr_transforms.append(SegChannelSelectionTransform(
                self.params.get("selected_seg_channels")))
        # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
        if self.params.get("dummy_2D", False):
            # NOTE(review): ignore_axes is computed but never used in this
            # class (no SimulateLowResolutionTransform below); MoreAug and
            # InsaneAug do use it. Kept for symmetry.
            ignore_axes = (0,)
            tr_transforms.append(Convert3DTo2DTransform())
        else:
            ignore_axes = None
        tr_transforms.append(SpatialTransform(
            self._spatial_transform_patch_size,
            patch_center_dist_from_border=None,
            do_elastic_deform=self.params.get("do_elastic"),
            alpha=self.params.get("elastic_deform_alpha"),
            sigma=self.params.get("elastic_deform_sigma"),
            do_rotation=self.params.get("do_rotation"),
            angle_x=self.params.get("rotation_x"),
            angle_y=self.params.get("rotation_y"),
            angle_z=self.params.get("rotation_z"),
            do_scale=self.params.get("do_scaling"),
            scale=self.params.get("scale_range"),
            order_data=self.params.get("order_data"),
            border_mode_data=self.params.get("border_mode_data"),
            border_cval_data=self.params.get("border_cval_data"),
            order_seg=self.params.get("order_seg"),
            border_mode_seg=self.params.get("border_mode_seg"),
            border_cval_seg=self.params.get("border_cval_seg"),
            random_crop=self.params.get("random_crop"),
            p_el_per_sample=self.params.get("p_eldef"),
            p_scale_per_sample=self.params.get("p_scale"),
            p_rot_per_sample=self.params.get("p_rot"),
            independent_scale_for_each_axis=self.params.get("independent_scale_factor_for_each_axis"),
        ))
        if self.params.get("dummy_2D"):
            tr_transforms.append(Convert2DTo3DTransform())
        # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
        # channel gets in the way
        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
        tr_transforms.append(GaussianBlurTransform((0.5, 1.),
                                                   different_sigma_per_channel=True,
                                                   p_per_sample=0.2,
                                                   p_per_channel=0.5))
        tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                                               p_per_sample=0.15))
        if self.params.get("do_additive_brightness"):
            tr_transforms.append(BrightnessTransform(
                self.params.get("additive_brightness_mu"),
                self.params.get("additive_brightness_sigma"),
                True,
                p_per_sample=self.params.get("additive_brightness_p_per_sample"),
                p_per_channel=self.params.get("additive_brightness_p_per_channel")))
        tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
        tr_transforms.append(GammaTransform(
            self.params.get("gamma_range"), True, True, retain_stats=self.params.get("gamma_retain_stats"),
            p_per_sample=0.1))  # inverted gamma
        if self.params.get("do_gamma"):
            tr_transforms.append(GammaTransform(
                self.params.get("gamma_range"),
                False,
                True,
                retain_stats=self.params.get("gamma_retain_stats"),
                p_per_sample=self.params["p_gamma"]))
        # NOTE(review): 'mirror' is asserted to be None above, so the second
        # operand can never be truthy; kept for backwards compatibility
        if self.params.get("do_mirror") or self.params.get("mirror"):
            tr_transforms.append(MirrorTransform(self.params.get("mirror_axes")))
        if self.params.get("use_mask_for_norm"):
            # zero out everything outside the normalization mask (seg channel 0)
            use_mask_for_norm = self.params.get("use_mask_for_norm")
            tr_transforms.append(MaskTransform(use_mask_for_norm, mask_idx_in_seg=0, set_outside_to=0))
        tr_transforms.append(RemoveLabelTransform(-1, 0))
        tr_transforms.append(RenameTransform('seg', 'target', True))
        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return Compose(tr_transforms)
@AUGMENTATION_REGISTRY.register
class MoreAug(NoAug):
    # Like BaseMoreAug but additionally simulates low resolution data.
    # Validation transforms are inherited from NoAug.
    def get_training_transforms(self):
        """
        Build the training augmentation pipeline.

        Returns:
            Compose: composed batchgenerators transforms; the pipeline
            renames 'seg' to 'target' and converts 'data'/'target' to
            float tensors at the end.
        """
        assert self.params.get('mirror') is None, "old version of params, use new keyword do_mirror"
        tr_transforms = []
        if self.params.get("selected_data_channels"):
            tr_transforms.append(DataChannelSelectionTransform(
                self.params.get("selected_data_channels")))
        if self.params.get("selected_seg_channels"):
            tr_transforms.append(SegChannelSelectionTransform(
                self.params.get("selected_seg_channels")))
        # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
        if self.params.get("dummy_2D", False):
            # the slice dimension must not be downsampled by the low
            # resolution simulation below
            ignore_axes = (0,)
            tr_transforms.append(Convert3DTo2DTransform())
        else:
            ignore_axes = None
        tr_transforms.append(SpatialTransform(
            self._spatial_transform_patch_size,
            patch_center_dist_from_border=None,
            do_elastic_deform=self.params.get("do_elastic"),
            alpha=self.params.get("elastic_deform_alpha"),
            sigma=self.params.get("elastic_deform_sigma"),
            do_rotation=self.params.get("do_rotation"),
            angle_x=self.params.get("rotation_x"),
            angle_y=self.params.get("rotation_y"),
            angle_z=self.params.get("rotation_z"),
            do_scale=self.params.get("do_scaling"),
            scale=self.params.get("scale_range"),
            order_data=self.params.get("order_data"),
            border_mode_data=self.params.get("border_mode_data"),
            border_cval_data=self.params.get("border_cval_data"),
            order_seg=self.params.get("order_seg"),
            border_mode_seg=self.params.get("border_mode_seg"),
            border_cval_seg=self.params.get("border_cval_seg"),
            random_crop=self.params.get("random_crop"),
            p_el_per_sample=self.params.get("p_eldef"),
            p_scale_per_sample=self.params.get("p_scale"),
            p_rot_per_sample=self.params.get("p_rot"),
            independent_scale_for_each_axis=self.params.get("independent_scale_factor_for_each_axis"),
        ))
        if self.params.get("dummy_2D"):
            tr_transforms.append(Convert2DTo3DTransform())
        # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
        # channel gets in the way
        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
        tr_transforms.append(GaussianBlurTransform((0.5, 1.),
                                                   different_sigma_per_channel=True,
                                                   p_per_sample=0.2,
                                                   p_per_channel=0.5))
        tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                                               p_per_sample=0.15))
        if self.params.get("do_additive_brightness"):
            tr_transforms.append(BrightnessTransform(
                self.params.get("additive_brightness_mu"),
                self.params.get("additive_brightness_sigma"),
                True,
                p_per_sample=self.params.get("additive_brightness_p_per_sample"),
                p_per_channel=self.params.get("additive_brightness_p_per_channel")))
        tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
        tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                                            per_channel=True,
                                                            p_per_channel=0.5,
                                                            order_downsample=0,
                                                            order_upsample=3,
                                                            p_per_sample=0.25,
                                                            ignore_axes=ignore_axes,
                                                            ))
        tr_transforms.append(GammaTransform(
            self.params.get("gamma_range"),
            True,
            True,
            retain_stats=self.params.get("gamma_retain_stats"),
            p_per_sample=0.1))  # inverted gamma
        if self.params.get("do_gamma"):
            tr_transforms.append(GammaTransform(
                self.params.get("gamma_range"),
                False,
                True,
                retain_stats=self.params.get("gamma_retain_stats"),
                p_per_sample=self.params["p_gamma"]))
        # NOTE(review): 'mirror' is asserted to be None above, so the second
        # operand can never be truthy; kept for backwards compatibility
        if self.params.get("do_mirror") or self.params.get("mirror"):
            tr_transforms.append(MirrorTransform(self.params.get("mirror_axes")))
        if self.params.get("use_mask_for_norm"):
            # zero out everything outside the normalization mask (seg channel 0)
            use_mask_for_norm = self.params.get("use_mask_for_norm")
            tr_transforms.append(MaskTransform(use_mask_for_norm,
                                               mask_idx_in_seg=0,
                                               set_outside_to=0))
        tr_transforms.append(RemoveLabelTransform(-1, 0))
        tr_transforms.append(RenameTransform('seg', 'target', True))
        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return Compose(tr_transforms)
@AUGMENTATION_REGISTRY.register
class InsaneAug(NoAug):
    # Most aggressive setup: like MoreAug but with stronger blur/brightness/
    # contrast ranges and higher probabilities for noise and gamma.
    # Validation transforms are inherited from NoAug.
    def get_training_transforms(self):
        """
        Build the training augmentation pipeline.

        Returns:
            Compose: composed batchgenerators transforms; the pipeline
            renames 'seg' to 'target' and converts 'data'/'target' to
            float tensors at the end.
        """
        assert self.params.get('mirror') is None, "old version of params, use new keyword do_mirror"
        tr_transforms = []
        if self.params.get("selected_data_channels"):
            tr_transforms.append(DataChannelSelectionTransform(
                self.params.get("selected_data_channels")))
        if self.params.get("selected_seg_channels"):
            tr_transforms.append(SegChannelSelectionTransform(
                self.params.get("selected_seg_channels")))
        # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
        if self.params.get("dummy_2D", False):
            # the slice dimension must not be downsampled by the low
            # resolution simulation below
            ignore_axes = (0,)
            tr_transforms.append(Convert3DTo2DTransform())
        else:
            ignore_axes = None
        tr_transforms.append(SpatialTransform(
            self._spatial_transform_patch_size,
            patch_center_dist_from_border=None,
            do_elastic_deform=self.params.get("do_elastic"),
            alpha=self.params.get("elastic_deform_alpha"),
            sigma=self.params.get("elastic_deform_sigma"),
            do_rotation=self.params.get("do_rotation"),
            angle_x=self.params.get("rotation_x"),
            angle_y=self.params.get("rotation_y"),
            angle_z=self.params.get("rotation_z"),
            do_scale=self.params.get("do_scaling"),
            scale=self.params.get("scale_range"),
            order_data=self.params.get("order_data"),
            border_mode_data=self.params.get("border_mode_data"),
            border_cval_data=self.params.get("border_cval_data"),
            order_seg=self.params.get("order_seg"),
            border_mode_seg=self.params.get("border_mode_seg"),
            border_cval_seg=self.params.get("border_cval_seg"),
            random_crop=self.params.get("random_crop"),
            p_el_per_sample=self.params.get("p_eldef"),
            p_scale_per_sample=self.params.get("p_scale"),
            p_rot_per_sample=self.params.get("p_rot"),
            independent_scale_for_each_axis=self.params.get("independent_scale_factor_for_each_axis"),
        ))
        if self.params.get("dummy_2D"):
            tr_transforms.append(Convert2DTo3DTransform())
        # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
        # channel gets in the way
        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
        tr_transforms.append(GaussianBlurTransform((0.5, 1.5),
                                                   different_sigma_per_channel=True,
                                                   p_per_sample=0.2,
                                                   p_per_channel=0.5),
                             )
        tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3),
                                                               p_per_sample=0.15))
        if self.params.get("do_additive_brightness"):
            tr_transforms.append(BrightnessTransform(
                self.params.get("additive_brightness_mu"),
                self.params.get("additive_brightness_sigma"),
                True,
                p_per_sample=self.params.get("additive_brightness_p_per_sample"),
                p_per_channel=self.params.get("additive_brightness_p_per_channel")))
        tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                                           p_per_sample=0.15))
        tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                                            per_channel=True,
                                                            p_per_channel=0.5,
                                                            order_downsample=0,
                                                            order_upsample=3,
                                                            p_per_sample=0.25,
                                                            ignore_axes=ignore_axes),
                             )
        tr_transforms.append(GammaTransform(
            self.params.get("gamma_range"),
            True,
            True,
            retain_stats=self.params.get("gamma_retain_stats"),
            p_per_sample=0.15))  # inverted gamma
        if self.params.get("do_gamma"):
            tr_transforms.append(GammaTransform(
                self.params.get("gamma_range"),
                False,
                True,
                retain_stats=self.params.get("gamma_retain_stats"),
                p_per_sample=self.params["p_gamma"]))
        # NOTE(review): 'mirror' is asserted to be None above, so the second
        # operand can never be truthy; kept for backwards compatibility
        if self.params.get("do_mirror") or self.params.get("mirror"):
            tr_transforms.append(MirrorTransform(self.params.get("mirror_axes")))
        if self.params.get("use_mask_for_norm"):
            # zero out everything outside the normalization mask (seg channel 0)
            use_mask_for_norm = self.params.get("use_mask_for_norm")
            tr_transforms.append(MaskTransform(use_mask_for_norm,
                                               mask_idx_in_seg=0,
                                               set_outside_to=0))
        tr_transforms.append(RemoveLabelTransform(-1, 0))
        tr_transforms.append(RenameTransform('seg', 'target', True))
        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return Compose(tr_transforms)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import shutil
import pickle
import numpy as np
from loguru import logger
from multiprocessing.pool import Pool
from pathlib import Path
from typing import List, Tuple, Sequence
from scipy.ndimage import binary_fill_holes
from nndet.io.paths import get_case_id_from_path
from nndet.io.load import load_case_from_list
def create_nonzero_mask(data: np.ndarray) -> np.ndarray:
    """
    Compute a binary foreground mask over all channels

    A voxel counts as foreground if any channel is nonzero there;
    enclosed holes in the resulting mask are filled.

    Args:
        data (np.ndarray): input data [C, X, Y, Z] or [C, X, Y]

    Returns:
        np.ndarray: boolean mask of the (hole filled) nonzero region
            [X, Y, Z] or [X, Y]
    """
    assert len(data.shape) == 4 or len(data.shape) == 3, \
        "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    foreground = (data != 0).any(axis=0)
    return binary_fill_holes(foreground.astype(bool))
def get_bbox_from_mask(mask: np.ndarray, outside_value: int = 0) -> List[Tuple]:
    """
    Compute the tight bounding box of all voxels differing from
    ``outside_value``

    Args:
        mask (np.ndarray): 2d or 3d mask [X, Y(, Z)]
        outside_value (int): value regarded as background

    Returns:
        List[Tuple]: per-dimension (min, max) index pairs with exclusive
            max, e.g. [(dim0_min, dim0_max), (dim1_min, dim1_max), ...]
    """
    coords_per_dim = (mask != outside_value).nonzero()
    # the first two spatial dimensions are always present
    bbox = [
        (int(np.min(axis_coords)), int(np.max(axis_coords)) + 1)
        for axis_coords in coords_per_dim[:2]
    ]
    if len(coords_per_dim) == 3:
        third_axis = coords_per_dim[2]
        bbox.append((int(np.min(third_axis)), int(np.max(third_axis)) + 1))
    return bbox
def crop_to_bbox_no_channels(image, bbox: Sequence[Sequence[int]]):
    """
    Crop an image (without channel dimension) to a bounding box

    Args:
        image (arraylike): 2d or 3d array
        bbox (Sequence[Sequence[int]]): per-dimension (start, stop) pairs in
            an interleaved fashion (e.g. (x1, x2), (y1, y2), (z1, z2));
            stop is exclusive

    Returns:
        arraylike: cropped array
    """
    slicer = tuple(slice(start, stop) for start, stop in bbox)
    return image[slicer]
def crop_to_bbox(data: np.ndarray, bbox: Sequence[Sequence[int]]):
    """
    Crops image to bounding box (performed per channel)

    Args:
        data (np.ndarray): 3d or 4d array [C, X, Y, (Z)]
        bbox (Sequence[Sequence[int]]): bounding box coordinated in an interleaved fashion
            (e.g. (x1, x2), (y1, y2), (z1, z2))

    Returns:
        np.ndarray: cropped array
    """
    # build the spatial slicer once and apply it to every channel
    slicer = tuple(slice(lower, upper) for lower, upper in bbox)
    return np.stack([channel[slicer] for channel in data])
def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """
    Crop data to nonzero region of data

    Args:
        data (np.ndarray): data to crop
        seg (np.ndarray): segmenation
        nonzero_label (int): nonzero label is written into segmentation map
            where only background was found

    Returns:
        np.ndarray: cropped data
        np.ndarray: cropped and filled (with nonzero_label) segmentation
        List[Tuple[int]]: bounding box of nonzero region
    """
    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask, 0)
    data = crop_to_bbox(data, bbox)
    # crop the mask to the same region and add a channel dimension
    cropped_mask = crop_to_bbox_no_channels(nonzero_mask, bbox)[None]
    if seg is not None:
        seg = crop_to_bbox(seg, bbox)
        # mark background voxels outside the nonzero region
        seg[(seg == 0) & (cropped_mask == 0)] = nonzero_label
    else:
        # no segmentation given: synthesize one from the nonzero mask
        # (0 inside the nonzero region, nonzero_label outside)
        cropped_mask = cropped_mask.astype(np.int32)
        cropped_mask[cropped_mask == 0] = nonzero_label
        cropped_mask[cropped_mask > 0] = 0
        seg = cropped_mask
    return data, seg, bbox
class ImageCropper(object):
    """
    Helper class to crop images to the joint nonzero region of all modalities.
    In the case of BRaTS and ISLES data this results in a significant
    reduction in image size.
    """

    def __init__(self, num_processes: int, output_dir: Path = None):
        """
        Args:
            num_processes (int): number of processes to use for cropping
            output_dir (Path): path to output directory
        """
        self.output_dir = Path(output_dir) if output_dir is not None else None
        self.num_processes = num_processes
        self.maybe_init_output_dir()

    def maybe_init_output_dir(self):
        """
        Create output directory if it does not already exist
        """
        if self.output_dir is not None and not self.output_dir.is_dir():
            self.output_dir.mkdir()

    def run_cropping(self, case_files: List[List[Path]], overwrite_existing: bool = False,
                     output_dir: Path = None, copy_gt_data: bool = True):
        """
        Crops data to non zero region and saves them into output_dir
        Optional: also copies ground truth data

        Args:
            case_files (List[List[Path]]): list with all cases in the structure [Case[Case Files]];
                where case files are sorted to corresponding modalities (last file is the label file)
            overwrite_existing (bool): overwrite existing crops
            output_dir (Path): path to output directory
            copy_gt_data (bool): copies ground truth data to output directory
        """
        if output_dir is not None:
            self.output_dir = Path(output_dir)
            self.maybe_init_output_dir()
        if copy_gt_data:
            self.copy_gt_data(case_files)
        list_of_args = []
        for case in case_files:
            case_id = get_case_id_from_path(str(case[0]))
            # case ids must not carry file extensions
            assert not case_id.endswith(".gz") and not case_id.endswith(".nii")
            list_of_args.append((case, case_id, overwrite_existing))
        with Pool(processes=self.num_processes) as p:
            p.map(self._process_data_star, list_of_args)

    def copy_gt_data(self, case_files: List[List[Path]]):
        """
        Copy ground truth (last file of each case) to `labelsTr` inside
        the output directory
        """
        output_dir_gt = self.output_dir / "labelsTr"
        if not output_dir_gt.is_dir():
            output_dir_gt.mkdir()
        for case in case_files:
            if case[-1] is not None:
                shutil.copy(case[-1], output_dir_gt)

    def _process_data_star(self, args):
        """
        Unpack argument tuple for :meth:`process_data` (Pool.map passes a
        single argument)
        """
        return self.process_data(*args)

    def process_data(self, case: List[Path], case_id: str, overwrite_existing: bool = False):
        """
        Extract nonzero region from all cases and create a single array where segmentation
        is located in the last channel and save as npz (saved in key `data`)
        Additional properties per case are saved inside a pkl file

        Args:
            case (List[Path]): list of paths to data and label (label is always at the last position
                and data is sorted after modalities)
            case_id (str): case identifier
            overwrite_existing (bool): overwrite existing data
        """
        try:
            logger.info(f"Processing case {case_id}")
            npz_exists = (self.output_dir / f"{case_id}.npz").is_file()
            pkl_exists = (self.output_dir / f"{case_id}.pkl").is_file()
            # Fix: reprocess whenever EITHER output file is missing. The old
            # condition `(not npz_exists and not pkl_exists)` only reprocessed
            # when both were missing, so a case with e.g. only the npz present
            # (crashed run) was never repaired.
            if overwrite_existing or not (npz_exists and pkl_exists):
                data, seg, properties = self.load_crop_from_list_of_files(case[:-1], case[-1])
                # segmentation is appended as the last channel
                all_data = np.vstack((data, seg))
                np.savez_compressed(self.output_dir / f"{case_id}.npz", data=all_data)
                with open(self.output_dir / f"{case_id}.pkl", 'wb') as f:
                    pickle.dump(properties, f)
            else:
                logger.warning(f"Case {case_id} already exists and overwrite is deactivated")
        except Exception as e:
            logger.info(f"exception in: {case_id}: {e}")
            raise e

    @staticmethod
    def load_crop_from_list_of_files(data_files: List[Path], seg_file: Path = None):
        """
        Load and crop from list of files

        Args:
            data_files (List[Path]): paths to data files
            seg_file (Path): path to segmentation

        Returns:
            np.ndarray: cropped data
            np.ndarray: cropped (and filled segmentation: -1 where no foreground exists) label
            dict: additional properties
                `original_size_of_raw_data`: original shape of data (correctly reordered)
                `original_spacing`: original spacing (correctly reordered)
                `list_of_data_files`: paths of data files
                `seg_file`: path to label file
                `itk_origin`: origin in world coordinates
                `itk_spacing`: spacing in world coordinates
                `itk_direction`: direction in world coordinates
                `crop_bbox`: List[Tuple[int]] cropped bounding box
                `classes`: present classes in segmentation
                `size_after_cropping`: size after cropping
        """
        data, seg, properties = load_case_from_list(data_files, seg_file)
        return ImageCropper.crop(data, properties, seg)

    @staticmethod
    def crop(data: np.ndarray, properties: dict, seg: np.ndarray = None):
        """
        Crop data and segmentation to non zero region

        Args:
            data (np.ndarray): data to crop [C, X, Y, Z]
            properties (dict): additional properties
            seg (np.ndarray): segmentation [1, X, Y, Z]

        Returns:
            data (np.ndarray): cropped data [C, X, Y, Z]
            seg (np.ndarray): cropped segmentation [1, X, Y, Z]
            properties (dict): newly added properties
                `crop_bbox`: List[Tuple[int]] cropped bounding box
                `classes`: present classes in segmentation
                `size_after_cropping`: size after cropping
        """
        data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
        properties["crop_bbox"] = bbox
        # classes are recorded BEFORE clamping so the nonzero label (-1) is included
        properties['classes'] = np.unique(seg)
        # values below -1 are artifacts; map them to background
        seg[seg < -1] = 0
        properties["size_after_cropping"] = data[0].shape
        return data, seg, properties
from typing import Iterable, Mapping
from nndet.utils.registry import Registry

# Global registry mapping dataloader names to their implementations.
# Dataloader classes add themselves via the ``@DATALOADER_REGISTRY.register``
# decorator.
# NOTE(review): the annotation `Mapping[str, Iterable]` looks off — the
# registered values are dataloader classes; confirm against the Registry
# implementation before relying on the annotation.
DATALOADER_REGISTRY: Mapping[str, Iterable] = Registry()

# Imported AFTER the registry is created so the register decorators in
# ``bg_loader`` can run at import time (import order matters here).
from nndet.io.datamodule.bg_loader import (
    DataLoader3DFast,
    DataLoader3DOffset,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from pathlib import Path
from collections import OrderedDict
import numpy as np
import pytorch_lightning as pl
from loguru import logger
from sklearn.model_selection import KFold
from nndet.io.utils import load_dataset_id
from nndet.io.load import load_pickle, save_pickle
class BaseModule(pl.LightningDataModule):
    def __init__(self,
                 plan: dict,
                 augment_cfg: dict,
                 data_dir: os.PathLike,
                 fold: int = 0,
                 **kwargs,
                 ):
        """
        Baseclass for nnDetection data modules.
        Overwrite :method:`setup` to customize the behavior.
        The train/val split is resolved inside the init so that
        :attr:`dataset_tr` and :attr:`dataset_val` are populated right
        after construction.

        Args:
            plan: plan file
            augment_cfg: provide settings for augmentation
                `splits_file` (str, optional): provide alternative splits file
            data_dir: path to preprocessed data dir. Needs to follow:
                `.../preprocessed/[data_identifier]/imagesTr
            fold: current fold; if None, does not create folds and uses
                whole dataset for training and validation (don't do this ...
                except you know what you are doing :P)
        """
        super().__init__(**kwargs)
        self.plan = plan
        self.augment_cfg = augment_cfg
        self.data_dir = Path(data_dir)
        self.fold = fold
        # .../preprocessed/[data_identifier]/imagesTr -> .../preprocessed
        self.preprocessed_dir = self.data_dir.parent.parent
        # Fix: the config key is documented as `splits_file` but only
        # `splits_final` was read before; honor the documented key first and
        # keep the old key as a backward compatible fallback.
        self.splits_file = self.augment_cfg.get(
            "splits_file",
            self.augment_cfg.get("splits_final", "splits_final.pkl"),
        )
        self.dataset_tr = {}
        self.dataset_val = {}
        self.dataset = load_dataset_id(self.data_dir)
        self.do_split()

    @property
    def splits_file(self) -> str:
        """str: name of the splits file (always carries a `.pkl` suffix)"""
        return self._splits_file

    @splits_file.setter
    def splits_file(self, f: str) -> None:
        # normalize the name: append `.pkl` if it is missing
        if f.endswith("pkl"):
            self._splits_file = f
        else:
            self._splits_file = f + ".pkl"

    def do_split(self) -> None:
        """
        Load a datasplit.
        If no split is found, a new split is created.
        Results are saved into :attr:`dataset_tr` and :attr:`dataset_val`
        """
        splits_file = self.preprocessed_dir / self.splits_file
        if not splits_file.is_file():
            self.create_new_split(splits_file)
        logger.info(f"Using splits {splits_file} with fold {self.fold}")
        splits = load_pickle(splits_file)
        if self.fold is None:
            logger.warning("USING SAME TRAIN AND VAL SET")
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            tr_keys = splits[self.fold]['train']
            val_keys = splits[self.fold]['val']
        # sort in place for deterministic iteration order
        tr_keys.sort()
        val_keys.sort()
        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]
        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]

    def create_new_split(self, splits_file: Path) -> None:
        """
        Create a new 5 fold split with a fixed seed

        Args:
            splits_file: path where splits file should be saved
        """
        logger.info("Creating new split...")
        splits = []
        all_keys_sorted = np.sort(list(self.dataset.keys()))
        # fixed seed so every training run reproduces the same split
        kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
        for train_idx, test_idx in kfold.split(all_keys_sorted):
            train_keys = np.array(all_keys_sorted)[train_idx]
            test_keys = np.array(all_keys_sorted)[test_idx]
            splits.append(OrderedDict())
            splits[-1]['train'] = train_keys
            splits[-1]['val'] = test_keys
        save_pickle(splits, splits_file)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from pathlib import Path
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
from loguru import logger
from batchgenerators.dataloading import SlimDataLoaderBase
from nndet.io.datamodule import DATALOADER_REGISTRY
from nndet.io.load import load_pickle
from nndet.io.patching import save_get_crop
from nndet.utils.info import maybe_verbose_iterable
from nndet.core.boxes.ops_np import box_size_np
class FixedSlimDataLoaderBase(SlimDataLoaderBase):
    """Slim dataloader reporting a fixed, configurable number of batches per epoch."""

    def __init__(self, *args, num_batches_per_epoch: int = 2500, **kwargs):
        """
        Args:
            num_batches_per_epoch: value returned by :meth:`__len__`
        """
        # stash before delegating so __len__ is valid as soon as
        # construction finishes
        self.num_batches_per_epoch = num_batches_per_epoch
        super().__init__(*args, **kwargs)

    def __len__(self) -> int:
        return self.num_batches_per_epoch
@DATALOADER_REGISTRY.register
class DataLoader3DFast(FixedSlimDataLoaderBase):
    """Patch based 3D dataloader sampling foreground from precomputed boxes."""

    def __init__(self,
                 data: Dict,
                 batch_size: int,
                 patch_size_generator: Sequence[int],
                 patch_size_final: Sequence[int],
                 oversample_foreground_percent: float = 0.5,
                 memmap_mode: str = "r+",
                 pad_mode: str = "constant",
                 pad_kwargs_data: Optional[Dict[str, Any]] = None,
                 num_batches_per_epoch: int = 2500,
                 ):
        """
        Basic Dataloader for 3D Data.
        Center of foreground patches is sampled from pre computed bounding
        boxes. Background patches are sampled randomly. Cases are selected
        randomly.

        Args:
            data: dict with cases and data paths
            batch_size: size of batches to generate
            patch_size_generator: patch size produced by the dataloader
            patch_size_final: final patch size after spatial transform
            oversample_foreground_percent: Oversample foreground patches.
                Each batch will be balanced to fulfill this criterion.
            memmap_mode: numpy memmap mode used to open the case arrays.
                Defaults to "r+".
            pad_mode: Padding mode for data. Defaults to "constant".
            pad_kwargs_data: Additional kwargs for data padding. Defaults to None.
            num_batches_per_epoch: number of batches reported per epoch
                (see :class:`FixedSlimDataLoaderBase`)

        Raises:
            ValueError: patch size of dataloader and final patch size need to
                have the same length
        """
        super().__init__(
            data=data,
            batch_size=batch_size,
            number_of_threads_in_multithreaded=None,
            num_batches_per_epoch=num_batches_per_epoch,
        )
        if len(patch_size_generator) != len(patch_size_final):
            raise ValueError(f"Final and generator patch size need to have the same length."
                             f"Found generator {patch_size_generator} and "
                             f"final {patch_size_final} patch size.")
        self.patch_size_generator = patch_size_generator
        self.patch_size_final = patch_size_final
        self.oversample_foreground_percent = oversample_foreground_percent
        self.memmap_mode = memmap_mode
        self.pad_mode = pad_mode
        self.pad_kwargs_data = pad_kwargs_data if pad_kwargs_data is not None else {}
        # we sample bigger patches and create a center crop during augmentation
        # to cover the borders of the patient we need to adjust the position
        self.need_to_pad = (np.array(patch_size_generator) - np.array(patch_size_final)).astype(np.int32)
        self.data_shape_batch, self.seg_shape_batch = self.determine_shapes()
        self.cache = self.build_cache()
        # NOTE(review): `build_cache` and `load_candidates` hard-code
        # 'boxes_file' instead of reading this attribute — confirm before
        # relying on overriding it
        self.candidates_key = "boxes_file"

    def determine_shapes(self) -> Tuple[Tuple[int], Tuple[int]]:
        """
        Determines data and segmentation shape to preallocate arrays
        during loading

        Raises:
            RuntimeError: Raised if data was not unpacked

        Returns:
            Tuple[Tuple[int], Tuple[int]]: Final shape of data,
                Final shape of seg (including batchdim)
        """
        # peek at an arbitrary case to read the channel counts
        k = list(self._data.keys())[0]
        if (p := Path(self._data[k]['data_file'])).is_file():
            data = np.load(str(p), self.memmap_mode, allow_pickle=False)
        else:
            raise RuntimeError("You shall not pass! Unpack data first!")
        if (p := Path(self._data[k]['seg_file'])).is_file():
            seg = np.load(str(p), self.memmap_mode, allow_pickle=False)
        else:
            raise RuntimeError("You shall not pass! Unpack data first!")
        num_data_channels = data.shape[0]
        num_seg_channels = seg.shape[0]
        data_shape = (self.batch_size, num_data_channels, *self.patch_size_generator)
        seg_shape = (self.batch_size, num_seg_channels, *self.patch_size_generator)
        return data_shape, seg_shape

    def build_cache(self) -> Dict[str, List]:
        """
        Build up cache for sampling

        Returns:
            Dict[str, List]: cache for sampling
                `case`: list with all case identifiers
                `instances`: list with tuple of (case_id, instance_id)
        """
        instance_cache = []
        logger.info("Building Sampling Cache for Dataloder")
        for case_id, item in maybe_verbose_iterable(self._data.items(), desc="Sampling Cache"):
            instances = load_pickle(item['boxes_file'])["instances"]
            if instances:
                for instance_id in instances:
                    instance_cache.append((case_id, instance_id))
        return {"case": list(self._data.keys()), "instances": instance_cache}

    def select(self) -> Tuple[List, List]:
        """
        Selects cases and instances. If instance id is -1 a random background
        patch will be sampled.
        Foreground sampling: sample uniformly from all the foreground classes
        and enforce the respective class while patch sampling.
        Background sampling: We just sample a random case

        Returns:
            List: case identifiers
            List: instance ids
                id > 0 indicates an instance
                id = -1 indicates a random (background) patch
        """
        selected_cases = []
        selected_instances = []
        # the first (1 - oversample) portion of the batch is background,
        # the rest is foreground
        for idx in range(self.batch_size):
            if idx < round(self.batch_size * (1 - self.oversample_foreground_percent)):
                # sample bg / random case
                selected_cases.append(np.random.choice(self.cache["case"]))
                selected_instances.append(-1)
            else:
                # sample fg / select an instance
                # NOTE(review): `idx` deliberately shadows the loop variable here;
                # it is not used again after this point
                idx = np.random.choice(range(len(self.cache["instances"])))
                _case, _instance_id = self.cache["instances"][idx]
                selected_cases.append(_case)
                selected_instances.append(int(_instance_id))
        return selected_cases, selected_instances

    def generate_train_batch(self) -> Dict[str, Any]:
        """
        Generate a single batch

        Returns:
            Dict: batch dict
                `data` (np.ndarray): data
                `seg` (np.ndarray): unordered(!) instance segmentation
                    Reordering needs to happen after final crop
                `instances` (List[Sequence[int]]): class for each instance in
                    the case (<- we can not extract them because we do not
                    know the present instances yet)
                `properties`(List[Dict]): properties of each case
                `keys` (List[str]): case ids
        """
        # preallocate the full batch arrays once
        data_batch = np.zeros(self.data_shape_batch, dtype=float)
        seg_batch = np.zeros(self.seg_shape_batch, dtype=float)
        instances_batch, properties_batch, case_ids_batch = [], [], []
        selected_cases, selected_instances = self.select()
        for batch_idx, (case_id, instance_id) in enumerate(zip(selected_cases, selected_instances)):
            case_data = np.load(self._data[case_id]['data_file'], self.memmap_mode, allow_pickle=True)
            case_seg = np.load(self._data[case_id]['seg_file'], self.memmap_mode, allow_pickle=True)
            properties = load_pickle(self._data[case_id]['properties_file'])
            # instance_id == -1 marks a random background patch (see select())
            if instance_id < 0:
                candidates = self.load_candidates(case_id=case_id, fg_crop=False)
                crop = self.get_bg_crop(
                    case_data=case_data,
                    case_seg=case_seg,
                    properties=properties,
                    case_id=case_id,
                    candidates=candidates,
                )
            else:
                candidates = self.load_candidates(case_id=case_id, fg_crop=True)
                crop = self.get_fg_crop(
                    case_data=case_data,
                    case_seg=case_seg,
                    properties=properties,
                    case_id=case_id,
                    instance_id=instance_id,
                    candidates=candidates,
                )
            data_batch[batch_idx] = save_get_crop(case_data,
                                                  crop=crop,
                                                  mode=self.pad_mode,
                                                  **self.pad_kwargs_data,
                                                  )[0]
            # segmentation is padded with -1 so padded voxels are ignored
            seg_batch[batch_idx] = save_get_crop(case_seg,
                                                 crop=crop,
                                                 mode='constant',
                                                 constant_values=-1,
                                                 )[0]
            case_ids_batch.append(case_id)
            instances_batch.append(properties.pop("instances"))
            properties_batch.append(properties)
        return {'data': data_batch,
                'seg': seg_batch,
                'properties': properties_batch,
                'instance_mapping': instances_batch,
                'keys': case_ids_batch,
                }

    def load_candidates(self, case_id: str, fg_crop: bool) -> Union[Dict, None]:
        """
        Load candidates for sampling

        Args:
            case_id: case id to load candidates from
            fg_crop: True if foreground crop will be sampled, False if
                background will be sampled

        Returns:
            Union[Dict, None]: dict if fg, None if bg
        """
        if fg_crop:
            return load_pickle(self._data[case_id]['boxes_file'])
        else:
            return None

    def get_fg_crop(self,
                    case_data: np.ndarray,
                    case_seg: np.ndarray,
                    properties: dict,
                    case_id: str,
                    instance_id: int,
                    candidates: Union[Dict, None],
                    ) -> List[slice]:
        """
        Sample foreground patches from precomputed boxes

        Args:
            case_data: case data (this should be a memmap!)
            case_seg: case segmentation (this should be a memmap!)
            properties: properties of case
            case_id: identifier of case
            instance_id: instance index to sample
            candidates: candidate positions to sample foreground from.
                Should not be None for this case.

        Returns:
            List[slice]: determined crop
        """
        assert candidates is not None
        # some instances might get lost during resampling so we need to find the correct index
        idx = candidates["instances"].index(instance_id)
        box = candidates["boxes"][idx]  # [6]
        # box layout (from the indices used below):
        # (d0_min, d1_min, d0_max, d1_max, d2_min, d2_max)
        # pick a random center inside the box and shift back by half a patch
        origin0 = np.random.randint(int(box[0]) + 1, int(box[2])) - (self.patch_size_generator[0] // 2)
        origin1 = np.random.randint(int(box[1]) + 1, int(box[3])) - (self.patch_size_generator[1] // 2)
        origin2 = np.random.randint(int(box[4]) + 1, int(box[5])) - (self.patch_size_generator[2] // 2)
        return [slice(origin0, origin0 + self.patch_size_generator[0]),
                slice(origin1, origin1 + self.patch_size_generator[1]),
                slice(origin2, origin2 + self.patch_size_generator[2])]

    def get_bg_crop(self,
                    case_data: np.ndarray,
                    case_seg: np.ndarray,
                    properties: dict,
                    case_id: str,
                    candidates: Union[Dict, None],
                    ) -> List[slice]:
        """
        Extract slices for (random) background crop

        Args:
            case_data: case data (this should be a memmap!)
            case_seg: case segmentation (this should be a memmap!)
            properties: properties of case
            case_id: identifier of case
            candidates: foreground candidates. Is not used in this
                specific implementation and thus None

        Returns:
            List[slice]: determined crop
        """
        data_shape = case_data.shape[1:]
        crop = []
        for ps, ds, _pad in zip(self.patch_size_generator, data_shape, self.need_to_pad):
            pad = _pad
            # if the scan is smaller than the patch, pad enough to fit it
            if pad + ds < ps:
                pad = ps - ds
            # random origin such that the (padded) patch stays inside the scan;
            # negative origins are handled by padding in save_get_crop
            origin = np.random.randint(-(pad // 2), ds + (pad // 2) + (pad % 2) - ps + 1)
            crop.append(slice(origin, origin + ps))
        return crop
@DATALOADER_REGISTRY.register
class DataLoader3DOffset(DataLoader3DFast):
    def get_fg_crop(self,
                    case_data: np.ndarray,
                    case_seg: np.ndarray,
                    properties: dict,
                    case_id: str,
                    instance_id: int,
                    candidates: Union[Dict, None],
                    ) -> List[slice]:
        """
        Sample foreground patches from precomputed boxes. In contrast to the
        base class, the patch is offset so the whole instance fits into the
        final patch whenever possible.

        Args:
            case_data: case data (this should be a memmap!)
            case_seg: case segmentation (this should be a memmap!)
            properties: properties of case
            case_id: identifier of case
            instance_id: instance index to sample
            candidates: candidate positions to sample foreground from.
                Should not be None for this case.

        Returns:
            List[slice]: determined crop
        """
        spatial_shape = case_data.shape[1:]
        # some instances might get lost during resampling so we need to find the correct index
        idx = candidates["instances"].index(instance_id)
        box = candidates["boxes"][[idx]]  # [1, 6]
        box_size = box_size_np(box)[0]
        box = box[0]
        origins = []
        # box layout: (d0_min, d1_min, d0_max, d1_max, d2_min, d2_max)
        for i, (ib, ib2) in enumerate([(0, 2), (1, 3), (4, 5)]):
            if spatial_shape[i] <= self.patch_size_generator[i]:  # patch larger than scan
                # we center the slice and pad the rest
                origins.append(- (self.need_to_pad[i] // 2))
            elif box_size[i] >= self.patch_size_final[i]:  # selected instance is larger than patch
                # we can not offset, we select our center point inside the bounding box and hope for the best
                center = np.random.randint(int(box[ib]) + 1, int(box[ib2]))
                # Fix: use the patch size of the CURRENT axis `i`; previously
                # axis 0 was used for every axis, which mis-centered patches
                # with anisotropic patch sizes along axes 1 and 2.
                origins.append(center - (self.patch_size_generator[i] // 2))
            else:  # create best effort offset
                patch_upper_bound = spatial_shape[i] - self.patch_size_final[i]
                lower_bound = np.clip(box[ib] - (self.patch_size_final[i] - box_size[i]),
                                      a_min=0, a_max=patch_upper_bound)
                upper_bound = np.clip(box[ib], a_min=0, a_max=patch_upper_bound)
                if lower_bound == upper_bound:
                    _origin = int(lower_bound)
                else:
                    _origin = np.random.randint(lower_bound, upper_bound)
                origins.append(_origin - (self.need_to_pad[i] // 2))
        return [slice(origins[0], origins[0] + self.patch_size_generator[0]),
                slice(origins[1], origins[1] + self.patch_size_generator[1]),
                slice(origins[2], origins[2] + self.patch_size_generator[2]),
                ]
@DATALOADER_REGISTRY.register
class DataLoader3DBalanced(DataLoader3DOffset):
    """Offset dataloader which samples foreground classes uniformly."""

    def build_cache(self) -> Dict[str, Any]:
        """
        Build up cache for sampling

        Returns:
            Dict[str, Any]: cache for sampling
                `fg`: Dict[int, List[Tuple[str, int]]] mapping each foreground
                    class to the (case_id, instance_id) pairs of its instances
                `case`: list with all case identifiers (used for random
                    background sampling)
        """
        fg_cache = defaultdict(list)
        logger.info("Building Sampling Cache for Dataloder")
        for case_id, item in maybe_verbose_iterable(self._data.items(), desc="Sampling Cache"):
            candidates = load_pickle(item['boxes_file'])
            if candidates["instances"]:
                for instance_id, instance_class in zip(candidates["instances"], candidates["labels"]):
                    fg_cache[int(instance_class)].append((case_id, instance_id))
        return {"fg": fg_cache, "case": list(self._data.keys())}

    def select(self) -> Tuple[List, List]:
        """
        Foreground sampling: sample uniformly from all the foreground classes
        and enforce the respective class while patch sampling.
        Background sampling: We just sample a random case

        Returns:
            List: case identifiers
            List: instance ids
                id >= 0 indicates an instance
                id = -1 indicates a random (background) patch
        """
        # draw a target class for every batch position; the classes drawn for
        # the background positions are simply ignored below
        selected_classes = np.random.choice(
            list(self.cache["fg"].keys()), self.batch_size, replace=True)
        selected_cases = []
        selected_instances = []
        for idx in range(len(selected_classes)):
            if idx < round(self.batch_size * (1 - self.oversample_foreground_percent)):
                # sample bg / random case
                selected_cases.append(np.random.choice(self.cache["case"]))
                selected_instances.append(-1)
            else:
                # sample fg / select an instance of the drawn class
                _i = np.random.choice(range(len(self.cache["fg"][selected_classes[idx]])))
                _case, _instance_id = self.cache["fg"][selected_classes[idx]][_i]
                selected_cases.append(_case)
                selected_instances.append(int(_instance_id))
        return selected_cases, selected_instances
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from typing import Iterable, Optional, List, Sequence, Type
import numpy as np
from loguru import logger
from batchgenerators.dataloading import (
MultiThreadedAugmenter,
SingleThreadedAugmenter,
)
from nndet.io.augmentation import AUGMENTATION_REGISTRY
from nndet.io.datamodule import DATALOADER_REGISTRY
from nndet.io.augmentation.base import AugmentationSetup
from nndet.io.datamodule.base import BaseModule
class FixedLengthMultiThreadedAugmenter(MultiThreadedAugmenter):
    """Multi-threaded augmenter whose length mirrors the wrapped generator."""

    def __len__(self) -> int:
        return len(self.generator)
class FixedLengthSingleThreadedAugmenter(SingleThreadedAugmenter):
    """Single-threaded augmenter whose length mirrors the wrapped dataloader."""

    def __len__(self) -> int:
        return len(self.data_loader)
def get_augmenter(dataloader,
                  transform,
                  num_processes: int,
                  num_cached_per_queue: int = 2,
                  multiprocessing: bool = True,
                  seeds: Optional[List[int]] = None,
                  pin_memory=True,
                  **kwargs,
                  ):
    """
    Wrapper to switch between multi-threaded and single-threaded augmenter

    Args:
        dataloader: dataloader to wrap
        transform: transform applied to every generated batch
        num_processes: number of augmentation workers (multi-threaded only)
        num_cached_per_queue: batches cached per worker queue (multi-threaded only)
        multiprocessing: use the multi-threaded augmenter
        seeds: optional per-worker seeds (multi-threaded only)
        pin_memory: pin batches in memory (multi-threaded only)
        kwargs: forwarded to the selected augmenter

    Returns:
        augmenter wrapping the dataloader
    """
    if not multiprocessing:
        return FixedLengthSingleThreadedAugmenter(
            data_loader=dataloader,
            transform=transform,
            **kwargs,
        )
    logger.info(f"Using {num_processes} num_processes "
                f"and {num_cached_per_queue} num_cached_per_queue for augmentation.")
    return FixedLengthMultiThreadedAugmenter(
        data_loader=dataloader,
        transform=transform,
        num_processes=num_processes,
        num_cached_per_queue=num_cached_per_queue,
        seeds=seeds,
        pin_memory=pin_memory,
        **kwargs,
    )
class Datamodule(BaseModule):
    def __init__(self,
                 plan: dict,
                 augment_cfg: dict,
                 data_dir: os.PathLike,
                 fold: int = 0,
                 **kwargs,
                 ):
        """
        Batchgenerator based datamodule

        Args:
            plan: current plan
            augment_cfg: provide settings for augmentation
                `splits_file` (str, optional): provide alternative splits file
                `oversample_foreground_percent` (float, optional):
                    ratio of foreground and background inside of batches,
                    defaults to 0.33
                `patch_size`(Sequence[int], optional): overwrite patch size
                `batch_size`(int, optional): overwrite batch size
            data_dir: path to preprocessed data dir
            fold: current fold; if None, does not create folds and uses
                whole dataset for training and validation (don't do this ...
                except you know what you are doing :P)
        """
        super().__init__(
            plan=plan,
            augment_cfg=augment_cfg,
            data_dir=data_dir,
            fold=fold,
            **kwargs,
        )
        # both populated in :meth:`setup`
        self.augmentation: Optional[Type[AugmentationSetup]] = None
        self.patch_size_generator: Optional[Sequence[int]] = None

    @property
    def patch_size(self):
        """
        Get patch size which can be (optionally) overwritten in the
        augmentation config
        """
        if "patch_size" in self.augment_cfg:
            ps = self.augment_cfg["patch_size"]
            logger.warning(f"Patch Size Overwrite Found: running patch size {ps}")
            return np.array(ps).astype(np.int32)
        else:
            return np.array(self.plan['patch_size']).astype(np.int32)

    @property
    def batch_size(self):
        """
        Get batch size which can be (optionally) overwritten in the
        augmentation config
        """
        if "batch_size" in self.augment_cfg:
            bs = self.augment_cfg["batch_size"]
            logger.warning(f"Batch Size Overwrite Found: running batch size {bs}")
            return bs
        else:
            return self.plan["batch_size"]

    @property
    def dataloader(self):
        """
        Get dataloader class name; the configured name may contain a `{}`
        placeholder which is filled with the network dimensionality
        """
        return self.augment_cfg['dataloader'].format(self.plan["network_dim"])

    @property
    def dataloader_kwargs(self):
        """
        Get dataloader kwargs which can be (optionally) overwritten in the
        augmentation config
        """
        dataloader_kwargs = self.plan.get('dataloader_kwargs', {})
        if dl_kwargs := self.augment_cfg.get("dataloader_kwargs", {}):
            logger.warning(f"Dataloader Kwargs Overwrite Found: {dl_kwargs}")
            dataloader_kwargs.update(dl_kwargs)
        return dataloader_kwargs

    def setup(self, stage: Optional[str] = None):
        """
        Process augmentation configurations and plan to determine the
        patch size, the patch size for the generator and create the
        augmentation object.
        """
        dim = len(self.patch_size)
        # NOTE: `params` is the dict stored in the config and is mutated
        # in place below
        params = self.augment_cfg["augmentation"]
        patch_size = self.patch_size
        if dim == 2:
            logger.info("Using 2D augmentation params")
            overwrites_2d = params.get("2d_overwrites", {})
            params.update(overwrites_2d)
        elif dim == 3 and self.plan['do_dummy_2D_data_aug']:
            # anisotropic 3d data: apply 2d spatial transforms slice-wise
            logger.info("Using dummy 2d augmentation params")
            params["dummy_2D"] = True
            params["elastic_deform_alpha"] = params["2d_overwrites"]["elastic_deform_alpha"]
            params["elastic_deform_sigma"] = params["2d_overwrites"]["elastic_deform_sigma"]
            params["rotation_x"] = params["2d_overwrites"]["rotation_x"]
        params["selected_seg_channels"] = [0]
        params["use_mask_for_norm"] = self.plan['use_mask_for_norm']
        # convert rotation ranges from degrees to radians
        params["rotation_x"] = [i / 180 * np.pi for i in params["rotation_x"]]
        params["rotation_y"] = [i / 180 * np.pi for i in params["rotation_y"]]
        params["rotation_z"] = [i / 180 * np.pi for i in params["rotation_z"]]
        augmentation_cls = AUGMENTATION_REGISTRY[params["transforms"]]
        self.augmentation = augmentation_cls(
            patch_size=patch_size,
            params=params,
        )
        self.patch_size_generator = self.augmentation.get_patch_size_generator()
        logger.info(f"Augmentation: {params['transforms']} transforms and "
                    f"{params.get('name', 'no_name')} params ")
        logger.info(f"Loading network patch size {self.augmentation.patch_size} "
                    f"and generator patch size {self.patch_size_generator}")

    def train_dataloader(self) -> Iterable:
        """
        Create training dataloader

        Returns:
            Iterable: dataloader for training
        """
        dataloader_cls = DATALOADER_REGISTRY.get(self.dataloader)
        logger.info(f"Using training {self.dataloader} with {self.dataloader_kwargs}")
        dl_tr = dataloader_cls(
            data=self.dataset_tr,
            batch_size=self.batch_size,
            # training samples enlarged patches to leave room for the
            # spatial transforms' center crop
            patch_size_generator=self.patch_size_generator,
            patch_size_final=self.patch_size,
            oversample_foreground_percent=self.augment_cfg[
                "oversample_foreground_percent"],
            pad_mode="constant",
            num_batches_per_epoch=self.augment_cfg[
                "num_train_batches_per_epoch"],
            **self.dataloader_kwargs,
        )
        tr_gen = get_augmenter(
            dataloader=dl_tr,
            transform=self.augmentation.get_training_transforms(),
            # capped worker count; one process is reserved for the main loop
            num_processes=min(int(self.augment_cfg.get('num_threads', 12)), 16) - 1,
            num_cached_per_queue=self.augment_cfg.get('num_cached_per_thread', 2),
            multiprocessing=self.augment_cfg.get("multiprocessing", True),
            seeds=None,
            pin_memory=True,
        )
        logger.info("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())))
        return tr_gen

    def val_dataloader(self) -> Iterable:
        """
        Create validation dataloader

        Returns:
            Iterable: dataloader for validation
        """
        dataloader_cls = DATALOADER_REGISTRY.get(self.dataloader)
        logger.info(f"Using validation {self.dataloader} with {self.dataloader_kwargs}")
        dl_val = dataloader_cls(
            data=self.dataset_val,
            batch_size=self.batch_size,
            # validation applies no spatial transforms, so generator and
            # final patch size coincide
            patch_size_generator=self.patch_size,
            patch_size_final=self.patch_size,
            oversample_foreground_percent=self.augment_cfg[
                "oversample_foreground_percent"],
            pad_mode="constant",
            num_batches_per_epoch=self.augment_cfg[
                "num_val_batches_per_epoch"],
            **self.dataloader_kwargs,
        )
        val_gen = get_augmenter(
            dataloader=dl_val,
            transform=self.augmentation.get_validation_transforms(),
            num_processes=min(int(self.augment_cfg.get('num_threads', 12)), 16) - 1,
            num_cached_per_queue=self.augment_cfg.get('num_cached_per_thread', 2),
            multiprocessing=self.augment_cfg.get("multiprocessing", True),
            seeds=None,
            pin_memory=True,
        )
        logger.info("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())))
        return val_gen
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pathlib import Path
import numpy as np
import SimpleITK as sitk
from itertools import product
from typing import Sequence, Union, Tuple
def create_circle_mask_itk(image_itk: sitk.Image,
                           world_centers: Sequence[Sequence[float]],
                           world_rads: Sequence[float],
                           ndim: int = 3,
                           ) -> sitk.Image:
    """
    Creates an itk image with circles defined by center points and radii

    Args:
        image_itk: original image (used for the coordinate frame)
        world_centers: Sequence of center points in world coordinates (x, y, z)
        world_rads: Sequence of radii to use
        ndim: number of spatial dimensions

    Returns:
        sitk.Image: mask with circles; voxels of the i-th circle carry the
            label i (1-based position inside ``world_centers``)
    """
    image_np = sitk.GetArrayFromImage(image_itk)
    min_spacing = min(image_itk.GetSpacing())
    # drop the leading (non-spatial) axis so the mask matches the spatial shape
    if image_np.ndim > ndim:
        image_np = image_np[0]
    mask_np = np.zeros_like(image_np).astype(np.uint8)
    for _id, (world_center, world_rad) in enumerate(zip(world_centers, world_rads), start=1):
        # conservative voxel-space search radius around the center
        check_rad = (world_rad / min_spacing) * 1.5 # add some buffer to it
        bounds = []
        # sitk uses (x, y, z) ordering while the numpy array is (z, y, x)
        center = image_itk.TransformPhysicalPointToContinuousIndex(world_center)[::-1]
        for ax, c in enumerate(center):
            # clip the search window to the mask boundaries
            bounds.append((
                max(0, int(c - check_rad)),
                min(mask_np.shape[ax], int(c + check_rad)),
            ))
        coord_box = product(*[list(range(b[0], b[1])) for b in bounds])
        # loop over every pixel position inside the search window and label the
        # voxels whose world-space distance to the center is within the radius
        for coord in coord_box:
            world_coord = image_itk.TransformIndexToPhysicalPoint(tuple(reversed(coord))) # reverse order to x, y, z for sitk
            dist = np.linalg.norm(np.array(world_coord) - np.array(world_center))
            if dist <= world_rad:
                mask_np[tuple(coord)] = _id
        # every circle must have marked at least one voxel
        assert mask_np.max() == _id
    mask_itk = sitk.GetImageFromArray(mask_np)
    return copy_meta_data_itk(image_itk, mask_itk)
def copy_meta_data_itk(source: sitk.Image, target: sitk.Image) -> sitk.Image:
    """
    Copy meta data between files

    Args:
        source: source file
        target: target file

    Returns:
        sitk.Image: target file with copied meta data

    Warnings:
        Currently disabled: this function unconditionally raises
        ``NotImplementedError``; everything below the raise is unreachable.
    """
    # for i in source.GetMetaDataKeys():
    #     target.SetMetaData(i, source.GetMetaData(i))
    raise NotImplementedError("Does not work!")
    # NOTE: dead code below — kept as a sketch of the intended behavior
    target.SetOrigin(source.GetOrigin())
    target.SetDirection(source.GetDirection())
    target.SetSpacing(source.GetSpacing())
    return target
def load_sitk(path: Union[Path, str], **kwargs) -> sitk.Image:
    """
    Read a single image from disk via SimpleITK

    Args:
        path: location of the image file
        **kwargs: forwarded to :func:`sitk.ReadImage`

    Returns:
        sitk.Image: the image stored at ``path``
    """
    return sitk.ReadImage(str(path), **kwargs)
def load_sitk_as_array(path: Union[Path, str], **kwargs) -> Tuple[np.ndarray, dict]:
    """
    Read an image via SimpleITK and convert it to a numpy array

    Args:
        path: location of the image file
        **kwargs: forwarded to :func:`load_sitk`

    Returns:
        np.ndarray: image content
        dict: meta data stored in the image header
    """
    image = load_sitk(path, **kwargs)
    meta_data = {}
    for key in image.GetMetaDataKeys():
        meta_data[key] = image.GetMetaData(key)
    return sitk.GetArrayFromImage(image), meta_data
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pickle
import json
import yaml
import time
from contextlib import contextmanager
from itertools import repeat
from multiprocessing.pool import Pool
from collections import OrderedDict
from pathlib import Path
from typing import Sequence, Any, Tuple, Union
from zipfile import BadZipfile
import numpy as np
import SimpleITK as sitk
from loguru import logger
from nndet.io.paths import subfiles, Pathlike
# public API of this module; note that several helpers (e.g. `save_txt`,
# `pack_dataset`, `unpack_dataset`, `del_npy`) are not exported here
__all__ = ["load_case_cropped", "load_case_from_list",
           "load_properties_of_cropped", "npy_dataset",
           "load_pickle", "load_json", "save_json", "save_pickle",
           "save_yaml", "load_npz_looped",
           ]
def load_case_from_list(data_files, seg_file=None) -> Tuple[np.ndarray, np.ndarray, dict]:
    """
    Load data and label of one case from list of paths

    Args:
        data_files (Sequence[Path]): paths to data files
        seg_file (Path): path to segmentation file (if a second file
            with a json ending is found, it is treated as an additional
            property file and will be loaded automatically)

    Returns:
        np.ndarray: loaded data (as float32) [C, X, Y, Z]
        np.ndarray: loaded segmentation (if no segmentation was provided, None)
            (as float32) [1, X, Y, Z]
        dict: additional properties of files
            `original_size_of_raw_data`: original shape of data (correctly reordered)
            `original_spacing`: original spacing (correctly reordered)
            `list_of_data_files`: paths of data files
            `seg_file`: path to label file
            `itk_origin`: origin in world coordinates
            `itk_spacing`: spacing in world coordinates
            `itk_direction`: direction in world coordinates
    """
    assert isinstance(data_files, Sequence), "case must be sequence"
    images = [sitk.ReadImage(str(f)) for f in data_files]
    reference = images[0]

    # sitk reports (x, y, z); reorder to the numpy (z, y, x) convention
    properties = OrderedDict()
    properties["original_size_of_raw_data"] = np.array(reference.GetSize())[[2, 1, 0]]
    properties["original_spacing"] = np.array(reference.GetSpacing())[[2, 1, 0]]
    properties["list_of_data_files"] = data_files
    properties["seg_file"] = seg_file
    properties["itk_origin"] = reference.GetOrigin()
    properties["itk_spacing"] = reference.GetSpacing()
    properties["itk_direction"] = reference.GetDirection()

    data = np.stack([sitk.GetArrayFromImage(img) for img in images])

    if seg_file is None:
        seg = None
    else:
        seg = sitk.GetArrayFromImage(sitk.ReadImage(str(seg_file)))[None].astype(np.float32)
        # an optional sidecar json next to the label holds extra properties
        seg_props_file = f"{str(seg_file).split('.')[0]}.json"
        if os.path.isfile(seg_props_file):
            with open(seg_props_file, "r") as f:
                properties.update(json.load(f))
    return data.astype(np.float32), seg, properties
def load_properties_of_cropped(path: Path):
    """
    Load the property file written after cropping
    (files are named after the case id with a .pkl ending)

    Args:
        path (Path): path to file (if .pkl is missing, it will be added automatically)

    Returns:
        Dict: loaded properties
    """
    if path.suffix != '.pkl':
        path = Path(str(path) + '.pkl')
    with open(path, 'rb') as handle:
        return pickle.load(handle)
def load_case_cropped(folder: Path, case_id: str) -> Tuple[np.ndarray, np.ndarray, dict]:
    """
    Load single case after cropping

    Args:
        folder (Path): path to folder where cases are located
        case_id (str): case identifier

    Returns:
        np.ndarray: data
        np.ndarray: segmentation
        dict: additional properties
    """
    base = os.path.join(folder, case_id)
    # data and segmentation are stacked into a single array; the
    # segmentation occupies the last channel
    stacked = load_npz_looped(base + ".npz", keys=["data"], num_tries=3)["data"]
    data, seg = stacked[:-1], stacked[-1]

    with open(base + ".pkl", "rb") as handle:
        props = pickle.load(handle)

    assert data.shape[1:] == seg.shape, (f"Data and segmentation need to have same dim (except first). "
                                         f"Found data {data.shape} and "
                                         f"mask {seg.shape} for case {case_id}")
    return data.astype(np.float32), seg.astype(np.int32), props
@contextmanager
def npy_dataset(folder: str, processes: int,
                unpack: bool = True, delete_npy: bool = True,
                delete_npz: bool = False):
    """
    Automatically unpacks the npz dataset and deletes npy data after completion

    Args:
        folder: path to folder
        processes: number of processes to use
        unpack: unpack data
        delete_npy: delete npy files at the end
        delete_npz: delete the npz file after conversion
    """
    root = Path(folder)
    if unpack:
        unpack_dataset(root, processes, delete_npz=delete_npz)
    try:
        yield True
    finally:
        # remove the uncompressed copies even if the with-body raised
        if delete_npy:
            del_npy(root)
def unpack_dataset(folder: Pathlike,
                   processes: int,
                   delete_npz: bool = False):
    """
    unpacks all npz files in a folder to npy

    Args
        folder: path to folder where data is located
        processes: number of processes to use
        delete_npz: delete the npz file after conversion
    """
    logger.info("Unpacking dataset")
    archives = subfiles(Path(folder), identifier="*.npz", join=True)
    # one (file, delete_flag) pair per archive, processed in parallel
    with Pool(processes) as pool:
        pool.starmap(npz2npy, zip(archives, repeat(delete_npz)))
def pack_dataset(folder, processes: int, key: str):
    """
    Pack dataset (from npy to npz)

    Args
        folder: path to folder where data is located
        processes: number of processes to use
        key: key under which the array is stored inside each npz file
    """
    logger.info("Packing dataset")
    arrays = subfiles(Path(folder), identifier="*.npy", join=True)
    # one (file, key) pair per array, processed in parallel
    with Pool(processes) as pool:
        pool.starmap(npy2npz, zip(arrays, repeat(key)))
def npz2npy(npz_file: str, delete_npz: bool = False):
    """
    convert npz to npy

    Args:
        npz_file: path to npz file
        delete_npz: delete the npz file after conversion
    """
    data_target = npz_file[:-3] + "npy"
    seg_target = npz_file[:-4] + "_seg.npy"
    # skip archives which were already unpacked
    if not os.path.isfile(data_target):
        loaded = load_npz_looped(npz_file, keys=["data", "seg"], num_tries=3)
        if loaded is not None:
            np.save(data_target, loaded["data"])
            np.save(seg_target, loaded["seg"])
    if delete_npz:
        os.remove(npz_file)
def npy2npz(npy_file: str, key: str):
    """
    convert npy to npz

    Args:
        npy_file: path to npy file
        key: key under which the array is stored in the npz file
    """
    payload = {key: np.load(npy_file)}
    np.savez_compressed(npy_file[:-3] + "npz", **payload)
def del_npy(folder: Pathlike):
    """
    Deletes all npy files inside folder
    """
    targets = [p for p in Path(folder).glob("*.npy") if os.path.isfile(p)]
    logger.info(f"Found {len(targets)} for removal")
    for target in targets:
        os.remove(target)
def load_json(path: Path, **kwargs) -> Any:
    """
    Load json file

    Args:
        path: path to json file (a ``.json`` suffix is appended if missing)
        **kwargs: keyword arguments passed to :func:`json.load`

    Returns:
        Any: json data
    """
    path = Path(path) if isinstance(path, str) else path
    if path.suffix != ".json":
        path = str(path) + ".json"
    with open(path, "r") as handle:
        return json.load(handle, **kwargs)
def save_json(data: Any, path: Pathlike, indent: int = 4, **kwargs):
    """
    Save data to a json file

    Args:
        data: data to save to json
        path: path to json file (a ``.json`` suffix is appended if missing)
        indent: passed to json.dump
        **kwargs: keyword arguments passed to :func:`json.dump`
    """
    if isinstance(path, str):
        path = Path(path)
    if not(".json" == path.suffix):
        path = Path(str(path) + ".json")
    with open(path, "w") as f:
        json.dump(data, f, indent=indent, **kwargs)
def load_pickle(path: Path, **kwargs) -> Any:
    """
    Load pickle file

    Args:
        path: path to pickle file (``.pkl`` is appended when the suffix is
            neither ``.pickle`` nor ``.pkl``)
        **kwargs: keyword arguments passed to :func:`pickle.load`

    Returns:
        Any: loaded data
    """
    if isinstance(path, str):
        path = Path(path)
    if path.suffix not in (".pickle", ".pkl"):
        path = Path(str(path) + ".pkl")
    with open(path, "rb") as handle:
        return pickle.load(handle, **kwargs)
def save_pickle(data: Any, path: Pathlike, **kwargs):
    """
    Save data to a pickle file

    Args:
        data: data to save to pickle
        path: path to pickle file (``.pkl`` is appended when the suffix is
            neither ``.pickle`` nor ``.pkl``)
        **kwargs: keyword arguments passed to :func:`pickle.dump`
    """
    if isinstance(path, str):
        path = Path(path)
    if not any([fix == path.suffix for fix in [".pickle", ".pkl"]]):
        path = str(path) + ".pkl"
    with open(str(path), "wb") as f:
        # pickle.dump returns None; the old `data = pickle.dump(...)` plus
        # `return data` always returned None while suggesting otherwise
        pickle.dump(data, f, **kwargs)
def save_yaml(data: Any, path: Path, **kwargs):
    """
    Save data to a yaml file

    Args:
        data: data to save to yaml
        path: path to yaml file (a ``.yaml`` suffix is appended if missing)
        **kwargs: keyword arguments passed to :func:`yaml.dump`
    """
    if isinstance(path, str):
        path = Path(path)
    if not(".yaml" == path.suffix):
        path = str(path) + ".yaml"
    with open(path, "w") as f:
        yaml.dump(data, f, **kwargs)
def save_txt(data: str, path: Path, **kwargs):
    """
    Append data to a txt file

    Args:
        data: data to save to txt
        path: path to txt file (a ``.txt`` suffix is appended if missing)
        **kwargs: currently unused; kept for interface consistency with the
            other save helpers in this module
    """
    if isinstance(path, str):
        path = Path(path)
    if not(".txt" == path.suffix):
        path = str(path) + ".txt"
    # note: opened in append mode — existing file content is preserved
    with open(path, "a") as f:
        f.write(str(data))
def load_npz_looped(
    p: Pathlike,
    keys: Sequence[str],
    *args,
    num_tries: int = 3,
    **kwargs,
) -> Union[dict, None]:
    """
    Try | Except loop to load numpy files
    (especially large numpy files can fail with BadZipFile Errors)

    Args:
        p: path to file to load
        keys: keys to load from npz file
        num_tries: number of tries to load file
        *args: passed to `np.load`
        **kwargs: passed to `np.load`

    Returns:
        Union[dict, None]: loaded data; None if the file could not be read
            after `num_tries` attempts
    """
    if num_tries <= 0:
        # fix: error message used to read "Num tires"
        raise ValueError(f"Num tries needs to be larger than 0, found {num_tries} tries.")
    for i in range(num_tries):  # try reading the file `num_tries` times
        try:
            _data = np.load(str(p), *args, **kwargs)
            data = {k: _data[k] for k in keys}
            break
        except Exception:
            if i == num_tries - 1:
                logger.error(f"Could not unpack {p}")
                return None
            # short back-off before retrying (read errors are often transient)
            time.sleep(5.)
    return data
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import typing
import itertools
import numpy as np
from loguru import logger
from skimage.measure import regionprops
import SimpleITK as sitk
def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
                            ) -> typing.List[tuple]:
    """
    Creates indices to crop patches around individual objects in mask

    Args
        mask: mask where objects have different numbers. Objects need to be
            numbered consecutively from one to n, with 0 as background.
        cshape: size of individual crops. Needs to be divisible by two.
            Otherwise crops do not have the expected size.
            If cshape is a int, crops will have the same size in every dimension.

    Returns
        list[tuple]: each crop generates one tuple with indices

    Raises
        TypeError: raised if mask and patches define different dimensionalities
        TypeError: raised if `cshape` is larger than mask

    See Also
        :func:`save_get_crop`

    Warnings
        The returned crops are not checked for image boundaries. Slices
        with negative indices and indices which extend over the mask boundaries
        are possible! To correct for this, use `save_get_crop` which handles
        this exceptions.
    """
    if isinstance(cshape, int):
        cshape = tuple([cshape] * mask.ndim)

    if mask.ndim != len(cshape):
        raise TypeError("Size of crops needs to be defined for "
                        "every dimension")
    if any(np.subtract(mask.shape, cshape) < 0):
        raise TypeError("Patches must be smaller than data.")

    if mask.max() == 0:
        # nothing to crop — background only
        return []

    centroids = [props['centroid'] for props in regionprops(mask.astype(np.int32))]
    # one centered crop per object centroid
    return [
        tuple(slice(int(pos) - (size // 2), int(pos) + (size // 2))
              for pos, size in zip(centroid, cshape))
        for centroid in centroids
    ]
def center_crop_object_seg(seg: np.ndarray, cshape: typing.Union[tuple, int],
                           **kwargs) -> typing.List[tuple]:
    """
    Creates indices to crop patches around individual objects in segmentation.
    Objects are determined by region growing with connected threshold.

    Args
        seg: semantic segmentation of objects.
        cshape: size of individual crops. Needs to be divisible by two.
            Otherwise crops do not have the expected size.
            If cshape is a int, crops will have the same size in every dimension.
        kwargs: additional keyword arguments passed to `center_crop_objects_mask`

    Returns
        list[tuple]: each crop generates one tuple with indices

    See Also
        :func:`save_get_crop`

    Warnings
        The returned crops are not checked for image boundaries. Slices
        with negative indices and indices which extend over the mask boundaries
        are possible! To correct for this, use `save_get_crop` which handles
        this exceptions.
    """
    # convert the semantic segmentation into an instance mask first
    object_mask, _ = create_mask_from_seg(seg)
    return center_crop_object_mask(object_mask, cshape=cshape, **kwargs)
def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]:
    """
    Create a mask where objects are enumerated from 1, ..., n.
    Objects are determined by region growing with connected threshold.

    Args
        seg: semantic segmentation array

    Returns
        np.ndarray: mask with objects
        list: classes to objects (ascending order)
    """
    _seg = np.copy(seg).astype(np.int32)
    # NOTE: the sitk image is built ONCE from the original segmentation;
    # region growing below always runs on this unmodified image while seeds
    # are drawn from the shrinking `_seg` copy
    _seg_sitk = sitk.GetImageFromArray(_seg)
    _mask = np.zeros_like(seg).astype(np.int32)
    _obj_cls = []
    _obj = 1
    # repeat until every foreground voxel was assigned to an object
    while _seg.max() > 0:
        # choose one seed in segmentation
        seed = np.transpose(np.nonzero(_seg))[0]
        # invert coordinates for sitk (numpy z, y, x -> sitk x, y, z)
        seed_sitk = tuple(seed[:: -1].tolist())
        seed = tuple(seed)
        # region growing: voxels connected to the seed with exactly the
        # seed's class value (lower == upper == class value)
        seg_con = sitk.ConnectedThreshold(_seg_sitk,
                                          seedList=[seed_sitk],
                                          lower=int(_seg[seed]),
                                          upper=int(_seg[seed]))
        seg_con = sitk.GetArrayFromImage(seg_con).astype(bool)
        # add object to mask
        _mask[seg_con] = _obj
        _obj_cls.append(_seg[seed])
        # remove object from segmentation so the next iteration selects a
        # seed from a different object
        _seg[seg_con] = 0
        _obj += 1
    # objects should never overlap
    assert _mask.max() < _obj
    return _mask, _obj_cls
def create_grid(cshape: typing.Union[typing.Sequence[int], int],
                dshape: typing.Sequence[int],
                overlap: typing.Union[typing.Sequence[int], int] = 0,
                mode='fixed',
                center_boarder: bool = False,
                **kwargs,
                ) -> typing.List[typing.Tuple[slice]]:
    """
    Create indices for a grid

    Args
        cshape: size of individual patches
        dshape: shape of data
        overlap: overlap between patches. If `overlap` is an integer is is applied
            to all dimensions.
        mode: defines how borders should be handled, by default 'fixed'.
            `fixed` created patches without special handling of borders, thus
            the last patch might exceed `dshape`
            `symmetric` moves patches such that the the first and last patch are
            equally overlapping of dshape (when combined with padding, the last and
            first patch would have the same amount of padding)
        center_boarder: adds additional crops at the boarders which have the
            boarder as their center
        kwargs: passed to the per-axis slicing function selected by `mode`

    Returns
        typing.List[slice]: slices to extract patches

    Raises
        TypeError: raised if `cshape` and `dshape` do not have the same length
        TypeError: raised if `overlap` and `dshape` do not have the same length
        TypeError: raised if `overlap` is larger than `cshape`

    Warnings
        The returned crops are can exceed the image boundaries. Slices
        with negative indices and indices which extend over the image
        boundary at the start. To correct for this, use `save_get_crop`
        which handles exceptions at borders.
    """
    _mode_fn = {
        "fixed": _fixed_slices,
        "symmetric": _symmetric_slices,
    }
    # special case: 2d patches on 3d data -> grid over individual slices.
    # fix: guard against an integer `cshape` (documented as supported),
    # which previously raised `TypeError: object of type 'int' has no len()`
    if len(dshape) == 3 and not isinstance(cshape, int) and len(cshape) == 2:
        logger.info("Creating 2d grid.")
        slices_3d = dshape[0]
        dshape = dshape[1:]
    else:
        slices_3d = None

    # create tuples from shapes
    if isinstance(cshape, int):
        cshape = tuple([cshape] * len(dshape))
    if isinstance(overlap, int):
        overlap = tuple([overlap] * len(dshape))

    # check shapes
    if len(cshape) != len(dshape):
        raise TypeError(
            "cshape and dshape must be defined for same dimensionality.")
    if len(overlap) != len(dshape):
        raise TypeError(
            "overlap and dshape must be defined for same dimensionality.")
    if any(np.subtract(dshape, cshape) < 0):
        axes = np.nonzero(np.subtract(dshape, cshape) < 0)
        logger.warning(f"Found patch size which is bigger than data: data {dshape} patch {cshape}")
    if any(np.subtract(cshape, overlap) < 0):
        raise TypeError("Overlap must be smaller than size of patches.")

    # per-axis slicing according to the selected border mode
    grid_slices = [_mode_fn[mode](psize, dlim, ov, **kwargs)
                   for psize, dlim, ov in zip(cshape, dshape, overlap)]

    if center_boarder:
        # additionally add one patch centered on each border per axis
        for idx, (psize, dlim, ov) in enumerate(zip(cshape, dshape, overlap)):
            lower_bound_start = int(-0.5 * psize)
            upper_bound_start = dlim - int(0.5 * psize)
            grid_slices[idx] = tuple([
                slice(lower_bound_start, lower_bound_start + psize),
                *grid_slices[idx],
                slice(upper_bound_start, upper_bound_start + psize),
            ])

    if slices_3d is not None:
        # prepend single-slice crops along the first (3d) axis
        grid_slices = [tuple([slice(i, i + 1) for i in range(slices_3d)])] + grid_slices
    grid = list(itertools.product(*grid_slices))
    return grid


def _fixed_slices(psize: int, dlim: int, overlap: int, start: int = 0) -> typing.Tuple[slice]:
    """
    Creates fixed slicing of a single axis. Only last patch exceeds dlim.

    Args
        psize: size of patch
        dlim: size of data
        overlap: overlap between patches
        start: where to start patches, by default 0

    Returns
        typing.Tuple[slice]: ordered slices for a single axis
    """
    upper_limit = 0
    lower_limit = start
    idx = 0
    crops = []
    while upper_limit < dlim:
        if idx != 0:
            # every patch after the first one starts `overlap` earlier
            lower_limit = lower_limit - overlap
        upper_limit = lower_limit + psize
        crops.append(slice(lower_limit, upper_limit))
        lower_limit = upper_limit
        idx += 1
    return tuple(crops)


def _symmetric_slices(psize: int, dlim: int, overlap: int) -> typing.Tuple[slice]:
    """
    Creates symmetric slicing of a single axis. First and last patch exceed
    data borders.

    Args
        psize: size of patch
        dlim: size of data
        overlap: overlap between patches

    Returns
        typing.Tuple[slice]: ordered slices for a single axis
    """
    if psize >= dlim:
        # single patch centered on the data
        return _fixed_slices(psize, dlim, overlap, start=-(psize - dlim) // 2)
    # distribute the remainder equally to both borders
    pmod = dlim % (psize - overlap)
    start = (pmod - psize) // 2
    return _fixed_slices(psize, dlim, overlap, start=start)
def save_get_crop(data: np.ndarray,
                  crop: typing.Sequence[slice],
                  mode: str = "shift",
                  **kwargs,
                  ) -> typing.Tuple[np.ndarray,
                                    typing.Tuple[int],
                                    typing.Tuple[slice]]:
    """
    Safely extract crops from data

    Args
        data: array the patch is extracted from
        crop: coordinates of a single crop given as slices
        mode: Handling of borders when crops are outside of data, by default "shift".
            Following modes are supported: "shift" crops are shifted inside the
            data | other modes are identical to `np.pad`
        kwargs: additional keyword arguments passed to `np.pad`

    Returns
        np.ndarray: crop from data
        Tuple[int]: origin offset of crop with regard to data origin (can be
            used to offset bounding boxes)
        Tuple[slice]: crop from data used to extract information

    Warnings
        This functions only supports positive indexing. Negative indices are
        interpreted like they were outside the lower boundary!
    """
    if len(crop) > data.ndim:
        raise TypeError(
            "crop must have smaller or same dimensionality as data.")
    if mode == 'shift':
        # move out-of-bounds crops back inside the data
        return _shifted_crop(data, crop)
    # otherwise fall back to clipping + np.pad
    return _padded_crop(data, crop, mode, **kwargs)
def _shifted_crop(data: np.ndarray,
crop: typing.Sequence[slice],
) -> typing.Tuple[np.ndarray,
typing.Tuple[int],
typing.Tuple[slice]]:
"""
Created shifted crops to handle borders
Args
data: crop is extracted from data
crop: defines boundaries of crops
Returns
List[np.ndarray]: list of crops
Tuple[int]: origin offset of crop with regard to data origin (can be
used to offset bounding boxes)
Tuple[slice]: crop from data used to extract information
Raises
TypeError: raised if patchsize is bigger than data
Warnings
This functions only supports positive indexing. Negative indices are
interpreted like they were outside the lower boundary!
"""
shifted_crop = []
dshape = tuple(data.shape)
# index from back, so batch and channel dimensions must not be defined
axis = data.ndim - len(crop)
for idx, crop_dim in enumerate(crop):
if crop_dim.start < 0:
# start is negative, thus it is subtracted from stop
new_slice = slice(0, crop_dim.stop - crop_dim.start, crop_dim.step)
if new_slice.stop > dshape[axis + idx]:
raise RuntimeError(
"Patch is bigger than entire data. shift "
"is not supported in this case.")
shifted_crop.append(new_slice)
elif crop_dim.stop > dshape[axis + idx]:
new_slice = \
slice(crop_dim.start - (crop_dim.stop - dshape[axis + idx]),
dshape[axis + idx], crop_dim.step)
if new_slice.start < 0:
raise RuntimeError(
"Patch is bigger than entire data. shift "
"is not supported in this case.")
shifted_crop.append(new_slice)
else:
shifted_crop.append(crop_dim)
origin = [int(x.start) for x in shifted_crop]
return data[tuple([..., *shifted_crop])], origin, shifted_crop
def _padded_crop(data: np.ndarray,
crop: typing.Sequence[slice],
mode: str,
**kwargs,
) -> typing.Tuple[np.ndarray,
typing.Tuple[int],
typing.Tuple[slice]]:
"""
Extract patch from data and pad accordingly
Args
data: crop is extracted from data
crop: defines boundaries of crops
mode: mode for padding. See `np.pad` for more details
kwargs: additional keyword arguments passed to :func:`np.pad`
Returns
typing.List[np.ndarray]: list of crops
Tuple[int]: origin offset of crop with regard to data origin (can be
used to offset bounding boxes)
Tuple[slice]: crop from data used to extract information
"""
clipped_crop = []
dshape = tuple(data.shape)
# index from back, so batch and channel dimensions must not be defined
axis = data.ndim - len(crop)
padding = [(0, 0)] * axis if axis > 0 else []
for idx, crop_dim in enumerate(crop):
lower_pad = 0
upper_pad = 0
lower_bound = crop_dim.start
upper_bound = crop_dim.stop
# handle lower bound
if lower_bound < 0:
lower_pad = -lower_bound
lower_bound = 0
# handle upper bound
if upper_bound > dshape[axis + idx]:
upper_pad = upper_bound - dshape[axis + idx]
upper_bound = dshape[axis + idx]
padding.append((lower_pad, upper_pad))
clipped_crop.append(slice(lower_bound, upper_bound, crop_dim.step))
origin = [int(x.start) for x in crop]
return (np.pad(data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs),
origin,
clipped_crop,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
# accepted by functions taking either a pathlib.Path or a plain string path
Pathlike = Union[Path, str]
def subfiles(dir_path: Path, identifier: str, join: bool) -> List[str]:
    """
    Collect all files inside a directory which match a pattern

    Args:
        dir_path: path to directory
        identifier: glob pattern used to select files
        join: return dir_path+file_name instead of file_name

    Returns:
        List[str]: found paths/file names
    """
    matches = [str(p) for p in Path(dir_path).glob(identifier)]
    if join:
        return matches
    # keep only the file name component
    return [m.rsplit(os.path.sep, 1)[-1] for m in matches]
def get_paths_raw_to_split(data_dir: Path, output_dir: Path,
                           subdirs: tuple = ("imagesTr", "imagesTs")) -> Tuple[
                               List[Path], List[Path]]:
    """
    Search subdirs for all *.nii.gz files which need to be splitted and
    create lists with source and target paths of all files
    (target paths retain subfolders inside of output dir)

    Args:
        data_dir (Path): top directory where data is located
        output_dir (Path): output directory for splitted data
        subdirs (Tuple[str]): subdirectories which should be searched for data

    Returns:
        List[Path]: path to all nii files in subfolders of source directory
        List[Path]: path to respective target directory
    """
    source_files: List[Path] = []
    target_dirs: List[Path] = []
    for subdir in subdirs:
        target_dir = output_dir / subdir
        target_dir.mkdir(parents=True, exist_ok=True)
        # skip hidden files (e.g. .DS_Store style artifacts)
        candidates = sorted(
            p for p in (data_dir / subdir).glob('*.nii.gz')
            if not p.name.startswith('.')
        )
        for nii_file in candidates:
            source_files.append(nii_file)
            target_dirs.append(target_dir)
    return source_files, target_dirs
def get_paths_from_splitted_dir(
        num_modalities: int,
        splitted_4d_output_dir: Path,
        test: bool = False,
        labels: bool = True,
        remove_ids: Optional[Sequence[str]] = None,
        ) -> List[List[Path]]:
    """
    Create list to all cases (data and label; label is at last position) inside splitted data dir

    Args:
        num_modalities (int): number of modalities
        splitted_4d_output_dir (Path): path to dir where 4d splitted data is located
        test: get paths from test data (if False, searches for train data)
        labels: add path to labels at last position of each case
        remove_ids: case ids which should be removed from the list. If None,
            no case ids are removed

    Returns:
        List[List[Path]]: paths to all splitted files;
            each case contains its data files and the label file is at the end
    """
    data_subdir = "imagesTs" if test else "imagesTr"
    labels_subdir = "labelsTs" if test else "labelsTr"

    case_ids = get_case_ids_from_dir(
        splitted_4d_output_dir / data_subdir,
        remove_modality=True,
    )
    if remove_ids is not None:
        case_ids = [cid for cid in case_ids if cid not in remove_ids]

    all_cases = []
    for cid in case_ids:
        # one file per modality, numbered with a 4 digit suffix
        case_paths = [
            splitted_4d_output_dir / data_subdir / f"{cid}_{mod:04d}.nii.gz"
            for mod in range(num_modalities)
        ]
        if labels:
            case_paths.append((splitted_4d_output_dir / labels_subdir) / f"{cid}.nii.gz")
        all_cases.append(case_paths)
    return all_cases
def get_case_ids_from_dir(dir_path: Path, unique: bool = True,
                          remove_modality: bool = True, join: bool = False,
                          pattern="*.nii.gz") -> List[str]:
    """
    Get all case ids from a single folder

    Args:
        dir_path: path to folder
        unique: remove all duplicates
        remove_modality: remove the modality string from the filename
        join: append case ids to directory path
        pattern: glob pattern used to select files

    Returns:
        List[str]: all case ids inside the folder
    """
    case_ids = [
        get_case_id_from_path(str(f), remove_modality=remove_modality)
        for f in Path(dir_path).glob(pattern)
    ]
    if unique:
        # note: de-duplication via set does not preserve ordering
        case_ids = list(set(case_ids))
    if join:
        case_ids = [os.path.join(dir_path, cid) for cid in case_ids]
    return case_ids
def get_case_id_from_path(file_path: Pathlike, remove_modality: bool = True) -> str:
    """
    Get case id from path to file

    Args:
        file_path (str): path to file as string
        remove_modality (bool): remove the modality string from the filename
            (only used if file ends with .nii.gz)

    Returns:
        str: case id
    """
    path_str = str(file_path)
    # keep only the final path component
    file_name = path_str.rsplit(os.path.sep, 1)[1]
    return get_case_id_from_file(file_name, remove_modality=remove_modality)
def get_case_id_from_file(file_name: str, remove_modality: bool = True) -> str:
    """
    Cut of ".nii.gz" from file name

    Args:
        file_name (str): name of file with .nii.gz ending
        remove_modality (bool): remove the modality string from the filename

    Returns:
        str: name of file without ending
    """
    stem = file_name.split('.')[0]
    if remove_modality:
        # strip the trailing `_XXXX` modality suffix (5 characters)
        stem = stem[:-5]
    return stem
def get_task(task_id: str, name: bool = False, models: bool = False) -> Union[Path, str]:
    """
    Resolve task name/dir

    Args:
        task_id: identifier of task.
            E.g. task dir = ../Task12_LIDC
            Possible task ids: Task12, LIDC, Task12_LIDC
        name: only return the name of the task
        models: uses model folder to look for names

    Returns:
        Union[Path, str]:
            path to data task directory if name is False
            name of task if name is True

    Raises:
        ValueError: environment is not configured, or the id does not
            resolve to exactly one task
    """
    if models:
        t = os.getenv("det_models")
    else:
        t = os.getenv("det_data")
    if t is None:
        raise ValueError("Framework not configured correctly! "
                         "Please set `det_data` and `det_models` as environment variables!")
    det_data = Path(t)
    # Normalize candidates to their prefix-free form unconditionally: the
    # result is rebuilt as f"Task{...}" below, so keeping the "Task" prefix
    # here (as happened previously when task_id lacked the prefix) produced
    # a non-existent "TaskTask..." directory.
    all_tasks = [d.stem[4:] if d.stem.startswith("Task") else d.stem
                 for d in det_data.iterdir() if d.is_dir() and "Task" in d.name]
    if task_id.startswith("Task"):
        task_id = task_id[4:]

    # Match by substring, by the number before the first '_' and by the name
    # after it; guard the name match so directories without an underscore
    # (e.g. "Task99") no longer raise an IndexError.
    task_options_exact = [d for d in all_tasks if task_id in d]
    task_number_id = [tn for tn in all_tasks if tn.split('_', 1)[0] == task_id]
    task_name_id = [tn for tn in all_tasks if '_' in tn and tn.split('_', 1)[1] == task_id]

    if len(task_options_exact) == 1:
        result = det_data / f"Task{task_options_exact[0]}"
    elif len(task_number_id) == 1:
        result = det_data / f"Task{task_number_id[0]}"
    elif len(task_name_id) == 1:
        result = det_data / f"Task{task_name_id[0]}"
    else:
        raise ValueError(f"Did not find task id {task_id}. "
                         f"Options are: {all_tasks}")
    if name:
        result = result.stem
    return result
def get_training_dir(model_dir: Pathlike, fold: int) -> Path:
    """
    Find training dir from a specific model dir

    Args:
        model_dir: path to model dir e.g. ../Task12_LIDC/RetinaUNetV0
        fold: fold to look for. if -1 look for consolidated dir

    Returns:
        Path: path to training dir

    Raises:
        ValueError: if not exactly one matching directory is found
    """
    base = Path(model_dir)
    token = "consolidated" if fold == -1 else f"fold{fold}"
    candidates = [sub for sub in base.iterdir()
                  if sub.is_dir() and token in sub.stem]
    if len(candidates) != 1:
        raise ValueError(f"Found wrong number of training dirs {candidates} in {base}")
    return candidates[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment