Commit d4a926ab authored by mibaumgartner

io

parent dbed904c
from nndet.io.load import load_json, load_pickle, save_json, save_pickle, npy_dataset, save_yaml
from nndet.io.paths import get_case_id_from_file, get_case_id_from_path, \
get_case_ids_from_dir, get_paths_from_splitted_dir, get_paths_raw_to_split, \
get_task, get_training_dir
from typing import Mapping, Type
from nndet.io.augmentation.base import AugmentationSetup
from nndet.utils.registry import Registry
AUGMENTATION_REGISTRY: Mapping[str, Type[AugmentationSetup]] = Registry()
from nndet.io.augmentation.bg_aug import (
NoAug,
DefaultAug,
BaseMoreAug,
MoreAug,
InsaneAug,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence, List
from abc import ABC, abstractmethod
import numpy as np
from batchgenerators.augmentations.utils import rotate_coords_2d, rotate_coords_3d
def get_patch_size(
patch_size: Sequence[int],
rot_x: float,
rot_y: float,
rot_z: float,
scale_range: Sequence[float],
) -> np.ndarray:
"""
Compute enlarged patch size for augmentations to reduce
artifacts at the borders before final cropping
Args:
        patch_size: target spatial size after final cropping
        rot_x: rotation in x in radians (scalar or range)
        rot_y: rotation in y in radians (scalar or range)
        rot_z: rotation in z in radians (scalar or range)
        scale_range: scaling range
Returns:
np.ndarray: enlarged patch size for augmentation
"""
if isinstance(rot_x, (tuple, list)):
rot_x = max(np.abs(rot_x))
if isinstance(rot_y, (tuple, list)):
rot_y = max(np.abs(rot_y))
if isinstance(rot_z, (tuple, list)):
rot_z = max(np.abs(rot_z))
rot_x = min(90 / 360 * 2. * np.pi, rot_x)
rot_y = min(90 / 360 * 2. * np.pi, rot_y)
rot_z = min(90 / 360 * 2. * np.pi, rot_z)
coords = np.array(patch_size)
final_shape = np.copy(coords)
if len(coords) == 3:
final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0)
final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0)
final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0)
elif len(coords) == 2:
final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)
final_shape /= min(scale_range)
return final_shape.astype(np.int32)
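# Illustrative sanity check (values invented, not taken from any plan file):
# a +/- 30 degree rotation range and a scale range of (0.7, 1.4) enlarge a
# [128, 128, 128] patch so rotation/scaling artifacts stay outside the final
# center crop.
#
# rot = (-30 / 360 * 2 * np.pi, 30 / 360 * 2 * np.pi)
# enlarged = get_patch_size([128, 128, 128], rot, rot, rot, (0.7, 1.4))
# enlarged is larger than [128, 128, 128] in every dimension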
class AugmentationSetup(ABC):
def __init__(self,
patch_size: Sequence[int],
params: dict,
) -> None:
"""
        Helper class for augmentation setup
Args:
patch_size: output patch size of augmentations
params: augmentation parameters
Notes:
The needed keys of :attr:`params` depend on the exact
transformations which should be used.
"""
self.patch_size = patch_size
self.params = params
@abstractmethod
def get_training_transforms(self):
"""
Setup training transformations
Needs to be overwritten in subclasses.
"""
raise NotImplementedError
@abstractmethod
def get_validation_transforms(self):
"""
Setup validation transformations
Needs to be overwritten in subclasses.
"""
raise NotImplementedError
def get_patch_size_generator(self) -> List[int]:
"""
Compute patch size to extract from volume to avoid augmentation
artifacts
"""
return list(get_patch_size(
patch_size=self.patch_size,
rot_x=self.params['rotation_x'],
rot_y=self.params['rotation_y'],
rot_z=self.params['rotation_z'],
scale_range=self.params['scale_range'],
))
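# Minimal sketch of a concrete subclass (illustrative only; the setups that
# nnDetection actually uses live in nndet.io.augmentation.bg_aug):
#
# class IdentityAug(AugmentationSetup):
#     def get_training_transforms(self):
#         # a real setup would compose batchgenerators transforms from self.params
#         return lambda **batch: batch
#
#     def get_validation_transforms(self):
#         return lambda **batch: batch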
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import shutil
import pickle
import numpy as np
from loguru import logger
from multiprocessing.pool import Pool
from pathlib import Path
from typing import List, Tuple, Sequence
from scipy.ndimage import binary_fill_holes
from nndet.io.paths import get_case_id_from_path
from nndet.io.load import load_case_from_list
def create_nonzero_mask(data: np.ndarray) -> np.ndarray:
"""
Create a nonzero mask from data
Args:
data (np.ndarray): input data [C, X, Y, Z]
Returns:
np.ndarray: binary mask on nonzero regions [X, Y, Z]
"""
assert len(data.shape) == 4 or len(data.shape) == 3, \
"data must have shape (C, X, Y, Z) or shape (C, X, Y)"
nonzero_mask = np.max(data != 0, axis=0)
nonzero_mask = binary_fill_holes(nonzero_mask.astype(bool))
return nonzero_mask
def get_bbox_from_mask(mask: np.ndarray, outside_value: int = 0) -> List[Tuple]:
"""
Create a bounding box from a mask
Args:
mask (np.ndarray): mask [X, Y, Z]
outside_value (int): background value
Returns:
        List[Tuple]: [(dim0_min, dim0_max), (dim1_min, dim1_max), (dim2_min, dim2_max)]
"""
mask_voxel_coords = (mask != outside_value).nonzero()
min0idx = int(np.min(mask_voxel_coords[0]))
max0idx = int(np.max(mask_voxel_coords[0])) + 1
min1idx = int(np.min(mask_voxel_coords[1]))
max1idx = int(np.max(mask_voxel_coords[1])) + 1
idx = [(min0idx, max0idx), (min1idx, max1idx)]
if len(mask_voxel_coords) == 3:
min2idx = int(np.min(mask_voxel_coords[2]))
max2idx = int(np.max(mask_voxel_coords[2])) + 1
idx.append((min2idx, max2idx))
return idx
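# Quick example (illustrative): a 2d mask with a single nonzero block
# >>> mask = np.zeros((3, 4)); mask[1, 1:3] = 1
# >>> get_bbox_from_mask(mask)
# [(1, 2), (1, 3)]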
def crop_to_bbox_no_channels(image, bbox: Sequence[Sequence[int]]):
"""
Crops image to bounding box (in spatial dimensions)
Args:
image (arraylike): 2d or 3d array
        bbox (Sequence[Sequence[int]]): bounding box coordinates in an interleaved fashion
(e.g. (x1, x2), (y1, y2), (z1, z2))
Returns:
arraylike: cropped array
"""
resizer = tuple([slice(_dim[0], _dim[1]) for _dim in bbox])
return image[resizer]
def crop_to_bbox(data: np.ndarray, bbox: Sequence[Sequence[int]]):
"""
Crops image to bounding box (performed per channel)
Args:
data (np.ndarray): 3d or 4d array [C, X, Y, (Z)]
        bbox (Sequence[Sequence[int]]): bounding box coordinates in an interleaved fashion
(e.g. (x1, x2), (y1, y2), (z1, z2))
Returns:
np.ndarray: cropped array
"""
cropped_data = []
for c in range(data.shape[0]):
cropped = crop_to_bbox_no_channels(data[c], bbox)
cropped_data.append(cropped)
data = np.stack(cropped_data)
return data
def crop_to_nonzero(data, seg=None, nonzero_label=-1):
"""
Crop data to nonzero region of data
Args:
data (np.ndarray): data to crop
        seg (np.ndarray): segmentation
nonzero_label (int): nonzero label is written into segmentation map
where only background was found
Returns:
np.ndarray: cropped data
np.ndarray: cropped and filled (with nonzero_label) segmentation
List[Tuple[int]]: bounding box of nonzero region
"""
nonzero_mask = create_nonzero_mask(data)
bbox = get_bbox_from_mask(nonzero_mask, 0)
data = crop_to_bbox(data, bbox)
if seg is not None:
seg = crop_to_bbox(seg, bbox)
nonzero_mask = crop_to_bbox_no_channels(nonzero_mask, bbox)[None]
if seg is not None:
seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label
else:
nonzero_mask = nonzero_mask.astype(np.int32)
nonzero_mask[nonzero_mask == 0] = nonzero_label
nonzero_mask[nonzero_mask > 0] = 0
seg = nonzero_mask
return data, seg, bbox
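# Illustrative usage: a single-channel volume padded with zeros
# >>> data = np.zeros((1, 4, 4, 4)); data[0, 1:3, 1:3, 1:3] = 1.
# >>> cropped, seg, bbox = crop_to_nonzero(data)
# >>> cropped.shape, bbox
# ((1, 2, 2, 2), [(1, 3), (1, 3), (1, 3)])
# seg marks voxels inside the nonzero mask with 0 and filled background with -1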
class ImageCropper(object):
def __init__(self, num_processes: int, output_dir: Path = None):
"""
        Helper class to crop images to the nonzero region (must hold for all modalities)
        In the case of BraTS and ISLES data this results in a significant reduction in image size
Args:
num_processes (int): number of processes to use for cropping
output_dir (Path): path to output directory
"""
self.output_dir = Path(output_dir) if output_dir is not None else None
self.num_processes = num_processes
self.maybe_init_output_dir()
def maybe_init_output_dir(self):
"""
        Create output directory if it does not already exist
"""
if self.output_dir is not None and not self.output_dir.is_dir():
self.output_dir.mkdir()
def run_cropping(self, case_files: List[List[Path]], overwrite_existing: bool = False,
output_dir: Path = None, copy_gt_data: bool = True):
"""
Crops data to non zero region and saves them into output_dir
Optional: also copies ground truth data
Args:
case_files (List[List[Path]]): list with all cases in the structure [Case[Case Files]];
where case files are sorted to corresponding modalities (last file is the label file)
overwrite_existing (bool): overwrite existing crops
output_dir (Path): path to output directory
copy_gt_data (bool): copies ground truth data to output directory
"""
if output_dir is not None:
self.output_dir = Path(output_dir)
self.maybe_init_output_dir()
if copy_gt_data:
self.copy_gt_data(case_files)
list_of_args = []
for _i, case in enumerate(case_files):
case_id = get_case_id_from_path(str(case[0]))
assert not case_id.endswith(".gz") and not case_id.endswith(".nii")
list_of_args.append((case, case_id, overwrite_existing))
with Pool(processes=self.num_processes) as p:
p.map(self._process_data_star, list_of_args)
def copy_gt_data(self, case_files: List[List[Path]]):
"""
Copy ground truth to output directory
"""
output_dir_gt = self.output_dir / "labelsTr"
if not output_dir_gt.is_dir():
output_dir_gt.mkdir()
for j, case in enumerate(case_files):
if case[-1] is not None:
shutil.copy(case[-1], output_dir_gt)
def _process_data_star(self, args):
"""
Unpack argument for function
"""
return self.process_data(*args)
def process_data(self, case: List[Path], case_id: str, overwrite_existing: bool = False):
"""
Extract nonzero region from all cases and create a single array where segmentation
is located in the last channel and save as npz (saved in key `data`)
Additional properties per case are saved inside a pkl file
Args:
case (List[Path]): list of paths to data and label (label is always at the last position
and data is sorted after modalities)
case_id (str): case identifier
overwrite_existing (bool): overwrite existing data
"""
try:
logger.info(f"Processing case {case_id}")
npz_exists = (self.output_dir / f"{case_id}.npz").is_file()
pkl_exists = (self.output_dir / f"{case_id}.pkl").is_file()
if (not npz_exists and not pkl_exists) or overwrite_existing:
data, seg, properties = self.load_crop_from_list_of_files(case[:-1], case[-1])
all_data = np.vstack((data, seg))
np.savez_compressed(self.output_dir / f"{case_id}.npz", data=all_data)
with open(self.output_dir / f"{case_id}.pkl", 'wb') as f:
pickle.dump(properties, f)
else:
logger.warning(f"Case {case_id} already exists and overwrite is deactivated")
except Exception as e:
logger.info(f"exception in: {case_id}: {e}")
raise e
@staticmethod
def load_crop_from_list_of_files(data_files: List[Path], seg_file: Path = None):
"""
        Load and crop from a list of files
Args:
data_files (List[Path]): paths to data files
            seg_file (Path): path to segmentation
Returns:
np.ndarray: cropped data
            np.ndarray: cropped segmentation (filled with -1 where no foreground exists)
dict: additional properties
`original_size_of_raw_data`: original shape of data (correctly reordered)
`original_spacing`: original spacing (correctly reordered)
`list_of_data_files`: paths of data files
`seg_file`: path to label file
`itk_origin`: origin in world coordinates
`itk_spacing`: spacing in world coordinates
`itk_direction`: direction in world coordinates
`crop_bbox`: List[Tuple[int]] cropped bounding box
`classes`: present classes in segmentation
`size_after_cropping`: size after cropping
"""
data, seg, properties = load_case_from_list(data_files, seg_file)
return ImageCropper.crop(data, properties, seg)
@staticmethod
def crop(data: np.ndarray, properties: dict, seg: np.ndarray = None):
"""
Crop data and segmentation to non zero region
Args:
data (np.ndarray): data to crop [C, X, Y, Z]
properties (dict): additional properties
seg (np.ndarray): segmentation [1, X, Y, Z]
Returns:
            data (np.ndarray): cropped data [C, X, Y, Z]
            seg (np.ndarray): cropped segmentation [1, X, Y, Z]
            properties (dict): properties with newly added keys
`crop_bbox`: List[Tuple[int]] cropped bounding box
`classes`: present classes in segmentation
`size_after_cropping`: size after cropping
"""
shape_before = data.shape
data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
shape_after = data.shape
# logger.info(f"Shape before crop {shape_before}; after crop {shape_after}; "
# f"spacing {np.array(properties['original_spacing'])}")
properties["crop_bbox"] = bbox
properties['classes'] = np.unique(seg)
seg[seg < -1] = 0
properties["size_after_cropping"] = data[0].shape
return data, seg, properties
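# Hedged usage sketch (paths and case ids are hypothetical):
# cropper = ImageCropper(num_processes=4, output_dir=Path("./cropped"))
# case_files = [[Path("imagesTr/case_000_0000.nii.gz"), Path("labelsTr/case_000.nii.gz")]]
# cropper.run_cropping(case_files)
# -> writes case_000.npz (key `data`, segmentation in the last channel) and
#    case_000.pkl (properties incl. `crop_bbox` and `size_after_cropping`)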
from typing import Iterable, Mapping
from nndet.utils.registry import Registry
DATALOADER_REGISTRY: Mapping[str, Iterable] = Registry()
from nndet.io.datamodule.bg_loader import (
DataLoader3DFast,
DataLoader3DBalanced,
DataLoader3DOffset,
DataLoader2DOffset,
DataLoader2DFast,
DataLoader2DDeeplesion,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from pathlib import Path
from collections import OrderedDict
import numpy as np
import pytorch_lightning as pl
from loguru import logger
from sklearn.model_selection import KFold
from nndet.io.utils import load_dataset_id
from nndet.io.load import load_pickle, save_pickle
class BaseModule(pl.LightningDataModule):
def __init__(self,
plan: dict,
augment_cfg: dict,
data_dir: os.PathLike,
fold: int = 0,
**kwargs,
):
"""
        Base class for nnDetection data modules.
        Overwrite :method:`setup` to customize the behavior.
        The splits are created inside the init.
Args:
plan: plan file
augment_cfg: provide settings for augmentation
`splits_file` (str, optional): provide alternative splits file
            data_dir: path to preprocessed data dir. Needs to follow:
                `.../preprocessed/[data_identifier]/imagesTr`
            fold: current fold; if None, does not create folds and uses the
                whole dataset for training and validation (don't do this
                unless you know what you are doing :P)
"""
super().__init__(**kwargs)
self.plan = plan
self.augment_cfg = augment_cfg
self.data_dir = Path(data_dir)
self.fold = fold
self.preprocessed_dir = self.data_dir.parent.parent
        self.splits_file = self.augment_cfg.get(
            "splits_file", "splits_final.pkl")
self.dataset_tr = {}
self.dataset_val = {}
self.dataset = load_dataset_id(self.data_dir)
self.do_split()
@property
def splits_file(self) -> str:
return self._splits_file
@splits_file.setter
def splits_file(self, f: str) -> None:
if f.endswith("pkl"):
self._splits_file = f
else:
self._splits_file = f + ".pkl"
def do_split(self) -> None:
"""
Load a datasplit.
        If no split is found, a new split is created.
Results are saved into :attr:`dataset_tr` and :attr:`dataset_val`
"""
splits_file = self.preprocessed_dir / self.splits_file
if not splits_file.is_file():
self.create_new_split(splits_file)
logger.info(f"Using splits {splits_file} with fold {self.fold}")
splits = load_pickle(splits_file)
if self.fold is None:
logger.warning(f"USING SAME TRAIN AND VAL SET")
tr_keys = val_keys = list(self.dataset.keys())
else:
tr_keys = splits[self.fold]['train']
val_keys = splits[self.fold]['val']
tr_keys.sort()
val_keys.sort()
self.dataset_tr = OrderedDict()
for i in tr_keys:
self.dataset_tr[i] = self.dataset[i]
self.dataset_val = OrderedDict()
for i in val_keys:
self.dataset_val[i] = self.dataset[i]
def create_new_split(self, splits_file: Path) -> None:
"""
Create a new 5 fold split with a fixed seed
Args:
splits_file: path where splits file should be saved
"""
logger.info("Creating new split...")
splits = []
all_keys_sorted = np.sort(list(self.dataset.keys()))
kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
train_keys = np.array(all_keys_sorted)[train_idx]
test_keys = np.array(all_keys_sorted)[test_idx]
splits.append(OrderedDict())
splits[-1]['train'] = train_keys
splits[-1]['val'] = test_keys
save_pickle(splits, splits_file)
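# The resulting splits file is a list with one dict per fold (illustrative):
# [
#     {"train": array(["case_000", ...]), "val": array(["case_004", ...])},
#     ...  # 5 folds in total
# ]
# It can be inspected with nndet.io.load.load_pickle(splits_file).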
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from pathlib import Path
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
from loguru import logger
from batchgenerators.dataloading import SlimDataLoaderBase
from nndet.io.datamodule import DATALOADER_REGISTRY
from nndet.io.load import load_pickle
from nndet.inference.patching import save_get_crop
from nndet.utils.info import maybe_verbose_iterable
from nndet.detection.boxes.utils_np import box_size_np
class FixedSlimDataLoaderBase(SlimDataLoaderBase):
def __init__(self,
*args,
num_batches_per_epoch: int = 2500,
**kwargs,
):
self.num_batches_per_epoch = num_batches_per_epoch
super().__init__(*args, **kwargs)
def __len__(self):
return self.num_batches_per_epoch
@DATALOADER_REGISTRY.register
class DataLoader3DFast(FixedSlimDataLoaderBase):
def __init__(self,
data: Dict,
batch_size: int,
patch_size_generator: Sequence[int],
patch_size_final: Sequence[int],
oversample_foreground_percent: float = 0.5,
memmap_mode: str = "r+",
pad_mode: str = "constant",
pad_kwargs_data: Optional[Dict[str, Any]] = None,
num_batches_per_epoch: int = 2500,
):
"""
        Basic Dataloader for 3D data.
        The center of foreground patches is sampled from precomputed bounding
        boxes. Background patches are sampled randomly. Cases are selected
        randomly.
Args:
data: dict with cases and data paths
batch_size: size of batches to generate
            patch_size_generator: patch size produced by the dataloader
            patch_size_final: final patch size after spatial transform
            oversample_foreground_percent: Oversample foreground patches.
                Each batch will be balanced to fulfill this criterion.
            memmap_mode: Do not change this. Defaults to "r+".
            pad_mode: Padding mode for data. Defaults to "constant".
            pad_kwargs_data: Additional kwargs for data padding. Defaults to None.
Raises:
            ValueError: patch size of dataloader and final patch size need to
                have the same length
"""
super().__init__(
data=data,
batch_size=batch_size,
number_of_threads_in_multithreaded=None,
num_batches_per_epoch=num_batches_per_epoch,
)
if len(patch_size_generator) != len(patch_size_final):
            raise ValueError(f"Final and generator patch size need to have the same length. "
f"Found generator {patch_size_generator} and "
f"final {patch_size_final} patch size.")
self.patch_size_generator = patch_size_generator
self.patch_size_final = patch_size_final
self.oversample_foreground_percent = oversample_foreground_percent
self.memmap_mode = memmap_mode
self.pad_mode = pad_mode
self.pad_kwargs_data = pad_kwargs_data if pad_kwargs_data is not None else {}
        # we sample bigger patches and create a center crop during augmentation;
        # to cover the borders of the patient we need to adjust the position
self.need_to_pad = (np.array(patch_size_generator) - np.array(patch_size_final)).astype(np.int32)
self.data_shape_batch, self.seg_shape_batch = self.determine_shapes()
self.cache = self.build_cache()
self.candidates_key = "boxes_file"
def determine_shapes(self) -> Tuple[Tuple[int], Tuple[int]]:
"""
Determines data and segmentation shape to preallocate arrays
during loading
Raises:
RuntimeError: Raised if data was not unpacked
Returns:
            Tuple[Tuple[int], Tuple[int]]: Final shape of data,
                final shape of seg (including batch dim)
"""
k = list(self._data.keys())[0]
if (p := Path(self._data[k]['data_file'])).is_file():
data = np.load(str(p), self.memmap_mode, allow_pickle=False)
else:
raise RuntimeError("You shall not pass! Unpack data first!")
if (p := Path(self._data[k]['seg_file'])).is_file():
seg = np.load(str(p), self.memmap_mode, allow_pickle=False)
else:
raise RuntimeError("You shall not pass! Unpack data first!")
num_data_channels = data.shape[0]
num_seg_channels = seg.shape[0]
data_shape = (self.batch_size, num_data_channels, *self.patch_size_generator)
seg_shape = (self.batch_size, num_seg_channels, *self.patch_size_generator)
return data_shape, seg_shape
def build_cache(self) -> Dict[str, List]:
"""
Build up cache for sampling
Returns:
Dict[str, List]: cache for sampling
`case`: list with all case identifiers
`instances`: list with tuple of (case_id, instance_id)
"""
instance_cache = []
        logger.info("Building Sampling Cache for Dataloader")
for case_id, item in maybe_verbose_iterable(self._data.items(), desc="Sampling Cache"):
instances = load_pickle(item['boxes_file'])["instances"]
if instances:
for instance_id in instances:
instance_cache.append((case_id, instance_id))
return {"case": list(self._data.keys()), "instances": instance_cache}
def select(self) -> Tuple[List, List]:
"""
Selects cases and instances. If instance id is -1 a random background
patch will be sampled.
Foreground sampling: sample uniformly from all the foreground classes
and enforce the respective class while patch sampling.
        Background sampling: we just sample a random case
Returns:
List: case identifiers
List: instance ids
id > 0 indicates an instance
id = -1 indicates a random (background) patch
"""
selected_cases = []
selected_instances = []
for idx in range(self.batch_size):
if idx < round(self.batch_size * (1 - self.oversample_foreground_percent)):
# sample bg / random case
selected_cases.append(np.random.choice(self.cache["case"]))
selected_instances.append(-1)
else:
# sample fg / select an instance
idx = np.random.choice(range(len(self.cache["instances"])))
_case, _instance_id = self.cache["instances"][idx]
selected_cases.append(_case)
selected_instances.append(int(_instance_id))
return selected_cases, selected_instances
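    # Example of the resulting batch composition: with batch_size=4 and
    # oversample_foreground_percent=0.5, round(4 * (1 - 0.5)) = 2, so batch
    # indices 0 and 1 sample background cases while indices 2 and 3 sample
    # foreground instances.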
def generate_train_batch(self) -> Dict[str, Any]:
"""
Generate a single batch
Returns:
Dict: batch dict
`data` (np.ndarray): data
`seg` (np.ndarray): unordered(!) instance segmentation
Reordering needs to happen after final crop
`instances` (List[Sequence[int]]): class for each instance in
the case (<- we can not extract them because we do not
know the present instances yet)
`properties`(List[Dict]): properties of each case
`keys` (List[str]): case ids
"""
data_batch = np.zeros(self.data_shape_batch, dtype=float)
seg_batch = np.zeros(self.seg_shape_batch, dtype=float)
instances_batch, properties_batch, case_ids_batch = [], [], []
selected_cases, selected_instances = self.select()
for batch_idx, (case_id, instance_id) in enumerate(zip(selected_cases, selected_instances)):
# print(case_id, instance_id)
case_data = np.load(self._data[case_id]['data_file'], self.memmap_mode, allow_pickle=True)
case_seg = np.load(self._data[case_id]['seg_file'], self.memmap_mode, allow_pickle=True)
properties = load_pickle(self._data[case_id]['properties_file'])
if instance_id < 0:
candidates = self.load_candidates(case_id=case_id, fg_crop=False)
crop = self.get_bg_crop(
case_data=case_data,
case_seg=case_seg,
properties=properties,
case_id=case_id,
candidates=candidates,
)
else:
candidates = self.load_candidates(case_id=case_id, fg_crop=True)
crop = self.get_fg_crop(
case_data=case_data,
case_seg=case_seg,
properties=properties,
case_id=case_id,
instance_id=instance_id,
candidates=candidates,
)
data_batch[batch_idx] = save_get_crop(case_data,
crop=crop,
mode=self.pad_mode,
**self.pad_kwargs_data,
)[0]
seg_batch[batch_idx] = save_get_crop(case_seg,
crop=crop,
mode='constant',
constant_values=-1,
)[0]
case_ids_batch.append(case_id)
instances_batch.append(properties.pop("instances"))
properties_batch.append(properties)
return {'data': data_batch,
'seg': seg_batch,
'properties': properties_batch,
'instance_mapping': instances_batch,
'keys': case_ids_batch,
}
def load_candidates(self, case_id: str, fg_crop: bool) -> Union[Dict, None]:
"""
Load candidates for sampling
Args:
case_id: case id to load candidates from
fg_crop: True if foreground crop will be sampled, False if
background will be sampled
Returns:
Union[Dict, None]: dict if fg, None if bg
"""
if fg_crop:
return load_pickle(self._data[case_id]['boxes_file'])
else:
return None
def get_fg_crop(self,
case_data: np.ndarray,
case_seg: np.ndarray,
properties: dict,
case_id: str,
instance_id: int,
candidates: Union[Dict, None],
) -> List[slice]:
"""
Sample foreground patches from precomputed boxes
Args:
case_data: case data (this should be a memmap!)
case_seg: case segmentation (this should be a memmap!)
properties: properties of case
case_id: identifier of case
instance_id: instance index to sample
candidates: candidate positions to sample foreground from.
Should not be None for this case.
Returns:
List[slice]: determined crop
"""
assert candidates is not None
# some instances might get lost during resampling so we need to find the correct index
idx = candidates["instances"].index(instance_id)
box = candidates["boxes"][idx] # [6]
origin0 = np.random.randint(int(box[0]) + 1, int(box[2])) - (self.patch_size_generator[0] // 2)
origin1 = np.random.randint(int(box[1]) + 1, int(box[3])) - (self.patch_size_generator[1] // 2)
origin2 = np.random.randint(int(box[4]) + 1, int(box[5])) - (self.patch_size_generator[2] // 2)
return [slice(origin0, origin0 + self.patch_size_generator[0]),
slice(origin1, origin1 + self.patch_size_generator[1]),
slice(origin2, origin2 + self.patch_size_generator[2])]
def get_bg_crop(self,
case_data: np.ndarray,
case_seg: np.ndarray,
properties: dict,
case_id: str,
candidates: Union[Dict, None],
) -> List[slice]:
"""
Extract slices for (random) background crop
Args:
case_data: case data (this should be a memmap!)
case_seg: case segmentation (this should be a memmap!)
properties: properties of case
case_id: identifier of case
candidates: foreground candidates. Is not used in this
specific implementation and thus None
Returns:
List[slice]: determined crop
"""
data_shape = case_data.shape[1:]
crop = []
for ps, ds, _pad in zip(self.patch_size_generator, data_shape, self.need_to_pad):
pad = _pad
if pad + ds < ps:
pad = ps - ds
origin = np.random.randint(-(pad // 2), ds + (pad // 2) + (pad % 2) - ps + 1)
crop.append(slice(origin, origin + ps))
return crop
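# Hedged sketch of the `data` mapping the loader expects (file names are
# hypothetical; the npz files must be unpacked to npy first):
#
# data = {
#     "case_000": {
#         "data_file": "preprocessed/case_000.npy",
#         "seg_file": "preprocessed/case_000_seg.npy",
#         "properties_file": "preprocessed/case_000.pkl",
#         "boxes_file": "preprocessed/case_000_boxes.pkl",
#     },
# }
# dl = DataLoader3DFast(data, batch_size=4,
#                       patch_size_generator=[160, 160, 96],
#                       patch_size_final=[128, 128, 80])
# batch = next(dl)  # dict with data/seg/properties/instance_mapping/keys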
@DATALOADER_REGISTRY.register
class DataLoader3DOffset(DataLoader3DFast):
def get_fg_crop(self,
case_data: np.ndarray,
case_seg: np.ndarray,
properties: dict,
case_id: str,
instance_id: int,
candidates: Union[Dict, None],
) -> List[slice]:
"""
Sample foreground patches from precomputed boxes
Args:
case_data: case data (this should be a memmap!)
case_seg: case segmentation (this should be a memmap!)
properties: properties of case
case_id: identifier of case
instance_id: instance index to sample
candidates: candidate positions to sample foreground from.
Should not be None for this case.
Returns:
List[slice]: determined crop
"""
spatial_shape = case_data.shape[1:]
# some instances might get lost during resampling so we need to find the correct index
idx = candidates["instances"].index(instance_id)
box = candidates["boxes"][[idx]] # [1, 6]
box_size = box_size_np(box)[0]
box = box[0]
origins = []
for i, (ib, ib2) in enumerate([(0, 2), (1, 3), (4, 5)]):
if spatial_shape[i] <= self.patch_size_generator[i]: # patch larger than scan
# we center the slice and pad the rest
origins.append(- (self.need_to_pad[i] // 2))
elif box_size[i] >= self.patch_size_final[i]: # selected instance is larger than patch
# we can not offset, we select our center point inside the bounding box and hope for the best
center = np.random.randint(int(box[ib]) + 1, int(box[ib2]))
                origins.append(center - (self.patch_size_generator[i] // 2))
else: # create best effort offset
patch_upper_bound = spatial_shape[i] - self.patch_size_final[i]
lower_bound = np.clip(box[ib] - (self.patch_size_final[i] - box_size[i]),
a_min=0, a_max=patch_upper_bound)
upper_bound = np.clip(box[ib], a_min=0, a_max=patch_upper_bound)
if lower_bound == upper_bound:
_origin = int(lower_bound)
else:
_origin = np.random.randint(lower_bound, upper_bound)
origins.append(_origin - (self.need_to_pad[i] // 2))
return [slice(origins[0], origins[0] + self.patch_size_generator[0]),
slice(origins[1], origins[1] + self.patch_size_generator[1]),
slice(origins[2], origins[2] + self.patch_size_generator[2]),
]
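# Numeric walk-through of the best-effort offset branch (values invented):
# spatial_shape[i] = 200, patch_size_final[i] = 128, box spans (40, 80) in
# dim i, so box_size[i] = 40:
#   patch_upper_bound = 200 - 128 = 72
#   lower_bound = clip(40 - (128 - 40), 0, 72) = 0
#   upper_bound = clip(40, 0, 72) = 40
# An origin drawn from [0, 40) always places the whole box inside the final
# patch; need_to_pad then shifts it for the larger generator patch.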
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from typing import Iterable, Optional, List, Sequence, Type
import numpy as np
from loguru import logger
from batchgenerators.dataloading import (
MultiThreadedAugmenter,
SingleThreadedAugmenter,
)
from nndet.io.augmentation import AUGMENTATION_REGISTRY
from nndet.io.datamodule import DATALOADER_REGISTRY
from nndet.io.augmentation.base import AugmentationSetup
from nndet.io.datamodule.base import BaseModule
class FixedLengthMultiThreadedAugmenter(MultiThreadedAugmenter):
def __len__(self):
return len(self.generator)
class FixedLengthSingleThreadedAugmenter(SingleThreadedAugmenter):
def __len__(self):
return len(self.data_loader)
def get_augmenter(dataloader,
transform,
num_processes: int,
num_cached_per_queue: int = 2,
multiprocessing: bool = True,
seeds: Optional[List[int]] = None,
pin_memory=True,
**kwargs,
):
"""
Wrapper to switch between multi-threaded and single-threaded augmenter
"""
if multiprocessing:
logger.info(f"Using {num_processes} num_processes "
f"and {num_cached_per_queue} num_cached_per_queue for augmentation.")
loader = FixedLengthMultiThreadedAugmenter(
data_loader=dataloader,
transform=transform,
num_processes=num_processes,
num_cached_per_queue=num_cached_per_queue,
seeds=seeds,
pin_memory=pin_memory,
**kwargs,
)
else:
loader = FixedLengthSingleThreadedAugmenter(
data_loader=dataloader,
transform=transform,
**kwargs,
)
return loader
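# Minimal sketch (single threaded, empty transform; `dl` stands for any
# dataloader instance from the registry above):
#
# from batchgenerators.transforms.abstract_transforms import Compose
# aug = get_augmenter(dataloader=dl, transform=Compose([]),
#                     num_processes=1, multiprocessing=False)
# batch = next(aug)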
class Datamodule(BaseModule):
def __init__(self,
plan: dict,
augment_cfg: dict,
data_dir: os.PathLike,
fold: int = 0,
**kwargs,
):
"""
Batchgenerator based datamodule
Args:
augment_cfg: provide settings for augmentation
`splits_file` (str, optional): provide alternative splits file
`oversample_foreground_percent` (float, optional):
ratio of foreground and background inside of batches,
defaults to 0.33
                `patch_size` (Sequence[int], optional): overwrite patch size
                `batch_size` (int, optional): overwrite batch size
            plan: current plan
            data_dir: path to preprocessed data dir
            fold: current fold; if None, does not create folds and uses the
                whole dataset for training and validation (don't do this
                unless you know what you are doing :P)
"""
super().__init__(
plan=plan,
augment_cfg=augment_cfg,
data_dir=data_dir,
fold=fold,
**kwargs,
)
self.augmentation: Optional[Type[AugmentationSetup]] = None
self.patch_size_generator: Optional[Sequence[int]] = None
@property
def patch_size(self):
"""
Get patch size which can be (optionally) overwritten in the
augmentation config
"""
if "patch_size" in self.augment_cfg:
ps = self.augment_cfg["patch_size"]
logger.warning(f"Patch Size Overwrite Found: running patch size {ps}")
return np.array(ps).astype(np.int32)
else:
return np.array(self.plan['patch_size']).astype(np.int32)
@property
def batch_size(self):
"""
Get batch size which can be (optionally) overwritten in the
augmentation config
"""
if "batch_size" in self.augment_cfg:
bs = self.augment_cfg["batch_size"]
logger.warning(f"Batch Size Overwrite Found: running batch size {bs}")
return bs
else:
return self.plan["batch_size"]
@property
def dataloader(self):
"""
Get dataloader class name
"""
return self.augment_cfg['dataloader'].format(self.plan["network_dim"])
@property
def dataloader_kwargs(self):
"""
Get dataloader kwargs which can be (optionally) overwritten in the
augmentation config
"""
dataloader_kwargs = self.plan.get('dataloader_kwargs', {})
if dl_kwargs := self.augment_cfg.get("dataloader_kwargs", {}):
logger.warning(f"Dataloader Kwargs Overwrite Found: {dl_kwargs}")
dataloader_kwargs.update(dl_kwargs)
return dataloader_kwargs
def setup(self, stage: Optional[str] = None):
"""
Process augmentation configurations and plan to determine the
patch size, the patch size for the generator and create the
augmentation object.
"""
dim = len(self.patch_size)
params = self.augment_cfg["augmentation"]
patch_size = self.patch_size
if dim == 2:
logger.info("Using 2D augmentation params")
overwrites_2d = params.get("2d_overwrites", {})
params.update(overwrites_2d)
elif dim == 3 and self.plan['do_dummy_2D_data_aug']:
logger.info("Using dummy 2d augmentation params")
params["dummy_2D"] = True
params["elastic_deform_alpha"] = params["2d_overwrites"]["elastic_deform_alpha"]
params["elastic_deform_sigma"] = params["2d_overwrites"]["elastic_deform_sigma"]
params["rotation_x"] = params["2d_overwrites"]["rotation_x"]
params["selected_seg_channels"] = [0]
params["use_mask_for_norm"] = self.plan['use_mask_for_norm']
params["rotation_x"] = [i / 180 * np.pi for i in params["rotation_x"]]
params["rotation_y"] = [i / 180 * np.pi for i in params["rotation_y"]]
params["rotation_z"] = [i / 180 * np.pi for i in params["rotation_z"]]
augmentation_cls = AUGMENTATION_REGISTRY[params["transforms"]]
self.augmentation = augmentation_cls(
patch_size=patch_size,
params=params,
)
self.patch_size_generator = self.augmentation.get_patch_size_generator()
logger.info(f"Augmentation: {params['transforms']} transforms and "
f"{params.get('name', 'no_name')} params ")
logger.info(f"Loading network patch size {self.augmentation.patch_size} "
f"and generator patch size {self.patch_size_generator}")
def train_dataloader(self) -> Iterable:
"""
Create training dataloader
Returns:
Iterable: dataloader for training
"""
dataloader_cls = DATALOADER_REGISTRY.get(self.dataloader)
logger.info(f"Using training {self.dataloader} with {self.dataloader_kwargs}")
dl_tr = dataloader_cls(
data=self.dataset_tr,
batch_size=self.batch_size,
patch_size_generator=self.patch_size_generator,
patch_size_final=self.patch_size,
oversample_foreground_percent=self.augment_cfg[
"oversample_foreground_percent"],
pad_mode="constant",
num_batches_per_epoch=self.augment_cfg[
"num_train_batches_per_epoch"],
**self.dataloader_kwargs,
)
tr_gen = get_augmenter(
dataloader=dl_tr,
transform=self.augmentation.get_training_transforms(),
num_processes=min(int(self.augment_cfg.get('num_threads', 12)), 16) - 1,
num_cached_per_queue=self.augment_cfg.get('num_cached_per_thread', 2),
multiprocessing=self.augment_cfg.get("multiprocessing", True),
seeds=None,
pin_memory=True,
)
logger.info("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())))
return tr_gen
def val_dataloader(self):
"""
Create validation dataloader
Returns:
Iterable: dataloader for validation
"""
dataloader_cls = DATALOADER_REGISTRY.get(self.dataloader)
logger.info(f"Using validation {self.dataloader} with {self.dataloader_kwargs}")
dl_val = dataloader_cls(
data=self.dataset_val,
batch_size=self.batch_size,
patch_size_generator=self.patch_size,
patch_size_final=self.patch_size,
oversample_foreground_percent=self.augment_cfg[
"oversample_foreground_percent"],
pad_mode="constant",
num_batches_per_epoch=self.augment_cfg[
"num_val_batches_per_epoch"],
**self.dataloader_kwargs,
)
val_gen = get_augmenter(
dataloader=dl_val,
transform=self.augmentation.get_validation_transforms(),
num_processes=min(int(self.augment_cfg.get('num_threads', 12)), 16) - 1,
num_cached_per_queue=self.augment_cfg.get('num_cached_per_thread', 2),
multiprocessing=self.augment_cfg.get("multiprocessing", True),
seeds=None,
pin_memory=True,
)
logger.info("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())))
return val_gen
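# Hedged sketch of the fields this module reads from augment_cfg (keys mirror
# the lookups above, values are illustrative):
#
# augment_cfg = {
#     "dataloader": "DataLoader{}DFast",  # formatted with plan["network_dim"]
#     "oversample_foreground_percent": 0.33,
#     "num_train_batches_per_epoch": 2500,
#     "num_val_batches_per_epoch": 100,
#     "augmentation": {"transforms": "DefaultAug", ...},  # plus param dict
#     "num_threads": 12,
#     "num_cached_per_thread": 2,
# }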
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pickle
import json
import yaml
import time
from contextlib import contextmanager
from itertools import repeat
from multiprocessing.pool import Pool
from collections import OrderedDict
from pathlib import Path
from typing import Sequence, Any, Tuple, Union
from zipfile import BadZipfile
import numpy as np
import SimpleITK as sitk
from loguru import logger
from nndet.io.paths import subfiles, Pathlike
__all__ = ["load_case_cropped", "load_case_from_list",
"load_properties_of_cropped", "npy_dataset",
"load_pickle", "load_json", "save_json", "save_pickle",
"save_yaml", "load_npz_looped",
]
def load_case_from_list(data_files, seg_file=None) -> Tuple[np.ndarray, np.ndarray, dict]:
"""
Load data and label of one case from list of paths
Args:
data_files (Sequence[Path]): paths to data files
seg_file (Path): path to segmentation file (if a second file
with a json ending is found, it is treated as an additional
property file and will be loaded automatically)
Returns:
        np.ndarray: loaded data (as float32) [C, X, Y, Z]
np.ndarray: loaded segmentation (if no segmentation was provided, None)
(as float32) [1, X, Y, Z]
dict: additional properties of files
`original_size_of_raw_data`: original shape of data (correctly reordered)
`original_spacing`: original spacing (correctly reordered)
`list_of_data_files`: paths of data files
`seg_file`: path to label file
`itk_origin`: origin in world coordinates
`itk_spacing`: spacing in world coordinates
`itk_direction`: direction in world coordinates
"""
assert isinstance(data_files, Sequence), "case must be sequence"
properties = OrderedDict()
data_itk = [sitk.ReadImage(str(f)) for f in data_files]
properties["original_size_of_raw_data"] = np.array(data_itk[0].GetSize())[[2, 1, 0]]
properties["original_spacing"] = np.array(data_itk[0].GetSpacing())[[2, 1, 0]]
properties["list_of_data_files"] = data_files
properties["seg_file"] = seg_file
properties["itk_origin"] = data_itk[0].GetOrigin()
properties["itk_spacing"] = data_itk[0].GetSpacing()
properties["itk_direction"] = data_itk[0].GetDirection()
data_npy = np.stack([sitk.GetArrayFromImage(d) for d in data_itk])
if seg_file is not None:
seg_itk = sitk.ReadImage(str(seg_file))
seg_npy = sitk.GetArrayFromImage(seg_itk)[None].astype(np.float32)
seg_props_file = f"{str(seg_file).split('.')[0]}.json"
if os.path.isfile(seg_props_file):
with open(seg_props_file, "r") as f:
properties.update(json.load(f))
else:
seg_npy = None
return data_npy.astype(np.float32), seg_npy, properties
def load_properties_of_cropped(path: Path):
"""
Load property file of after cropping was performed
(files are name after case id and .pkl ending)
Args:
path (Path): path to file (if .pkl is missing, it will be added automatically)
Returns:
Dict: loaded properties
"""
if not path.suffix == '.pkl':
path = Path(str(path) + '.pkl')
with open(path, 'rb') as f:
properties = pickle.load(f)
return properties
def load_case_cropped(folder: Path, case_id: str) -> Tuple[np.ndarray, np.ndarray, dict]:
"""
Load single case after cropping
Args:
folder (Path): path to folder where cases are located
case_id (str): case identifier
Returns:
np.ndarray: data
np.ndarray: segmentation
dict: additional properties
"""
stack = load_npz_looped(os.path.join(folder, case_id) + ".npz",
keys=["data"], num_tries=3,
)["data"]
data = stack[:-1]
seg = stack[-1]
with open(os.path.join(folder, case_id) + ".pkl", "rb") as f:
props = pickle.load(f)
assert data.shape[1:] == seg.shape, (f"Data and segmentation need to have same dim (except first). "
f"Found data {data.shape} and "
f"mask {seg.shape} for case {case_id}")
return data.astype(np.float32), seg.astype(np.int32), props
@contextmanager
def npy_dataset(folder: str, processes: int,
unpack: bool = True, delete_npy: bool = True,
delete_npz: bool = False):
"""
Automatically unpacks the npz dataset and deletes npy data after completion
Args:
folder: path to folder
processes: number of processes to use
unpack: unpack data
delete_npy: delete npy files at the end
delete_npz: delete the npz file after conversion
"""
if unpack:
unpack_dataset(Path(folder), processes, delete_npz=delete_npz)
try:
yield True
finally:
if delete_npy:
del_npy(Path(folder))
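# Typical usage (path hypothetical): unpack npz files for training and remove
# the npy files again afterwards
#
# with npy_dataset("preprocessed/imagesTr", processes=4):
#     ...  # training can now memory-map the npy files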
def unpack_dataset(folder: Pathlike,
processes: int,
delete_npz: bool = False):
"""
    Unpack all npz files in a folder to npy
    (the arrays must be saved under the keys `data` and `seg`)
    Args:
        folder: path to folder where data is located
        processes: number of processes to use
        delete_npz: delete the npz file after conversion
"""
logger.info("Unpacking dataset")
npz_files = subfiles(Path(folder), identifier="*.npz", join=True)
with Pool(processes) as p:
p.starmap(npz2npy, zip(npz_files, repeat(delete_npz)))
def pack_dataset(folder, processes: int, key: str):
"""
Pack dataset (from npy to npz)
    Args:
folder: path to folder where data is located
processes: number of processes to use
key: key which should be extracted
"""
logger.info("Packing dataset")
npy_files = subfiles(Path(folder), identifier="*.npy", join=True)
with Pool(processes) as p:
p.starmap(npy2npz, zip(npy_files, repeat(key)))
def npz2npy(npz_file: str, delete_npz: bool = False):
"""
convert npz to npy
Args:
npz_file: path to npz file
delete_npz: delete the npz file after conversion
"""
if not os.path.isfile(npz_file[:-3] + "npy"):
a = load_npz_looped(npz_file, keys=["data", "seg"], num_tries=3)
if a is not None:
np.save(npz_file[:-3] + "npy", a["data"])
np.save(npz_file[:-4] + "_seg.npy", a["seg"].astype(np.int16))
if delete_npz:
os.remove(npz_file)
def npy2npz(npy_file: str, key: str):
"""
convert npy to npz
Args:
npy_file: path to npy file
key: key to extract
"""
d = np.load(npy_file)
np.savez_compressed(npy_file[:-3] + "npz", **{key: d})
def del_npy(folder: Pathlike):
"""
Deletes all npy files inside folder
"""
npy_files = Path(folder).glob("*.npy")
npy_files = [i for i in npy_files if os.path.isfile(i)]
    logger.info(f"Found {len(npy_files)} npy files for removal")
for n in npy_files:
os.remove(n)
def load_json(path: Path, **kwargs) -> Any:
"""
Load json file
Args:
path: path to json file
**kwargs: keyword arguments passed to :func:`json.load`
Returns:
Any: json data
"""
if isinstance(path, str):
path = Path(path)
if not(".json" == path.suffix):
path = str(path) + ".json"
with open(path, "r") as f:
data = json.load(f, **kwargs)
return data
def save_json(data: Any, path: Pathlike, indent: int = 4, **kwargs):
"""
    Save data to a json file
Args:
data: data to save to json
path: path to json file
indent: passed to json.dump
**kwargs: keyword arguments passed to :func:`json.dump`
"""
if isinstance(path, str):
path = Path(path)
if not(".json" == path.suffix):
path = Path(str(path) + ".json")
with open(path, "w") as f:
json.dump(data, f, indent=indent, **kwargs)
def load_pickle(path: Path, **kwargs) -> Any:
"""
Load pickle file
Args:
path: path to pickle file
**kwargs: keyword arguments passed to :func:`pickle.load`
Returns:
Any: json data
"""
if isinstance(path, str):
path = Path(path)
if not any([fix == path.suffix for fix in [".pickle", ".pkl"]]):
path = Path(str(path) + ".pkl")
with open(path, "rb") as f:
data = pickle.load(f, **kwargs)
return data
def save_pickle(data: Any, path: Pathlike, **kwargs):
"""
    Save data to a pickle file
Args:
data: data to save to pickle
path: path to pickle file
**kwargs: keyword arguments passed to :func:`pickle.dump`
"""
if isinstance(path, str):
path = Path(path)
if not any([fix == path.suffix for fix in [".pickle", ".pkl"]]):
path = str(path) + ".pkl"
    with open(str(path), "wb") as f:
        pickle.dump(data, f, **kwargs)
def save_yaml(data: Any, path: Path, **kwargs):
"""
    Save data to a yaml file
Args:
data: data to save to yaml
path: path to yaml file
**kwargs: keyword arguments passed to :func:`yaml.dump`
"""
if isinstance(path, str):
path = Path(path)
if not(".yaml" == path.suffix):
path = str(path) + ".yaml"
with open(path, "w") as f:
yaml.dump(data, f, **kwargs)
def save_txt(data: str, path: Path, **kwargs):
"""
    Append data to a txt file
    Args:
        data: data to save to txt
        path: path to txt file
        **kwargs: currently unused
"""
if isinstance(path, str):
path = Path(path)
if not(".txt" == path.suffix):
path = str(path) + ".txt"
with open(path, "a") as f:
f.write(str(data))
def load_npz_looped(
p: Pathlike,
keys: Sequence[str],
*args,
num_tries: int = 3,
**kwargs,
) -> Union[dict, None]:
"""
Try | Except loop to load numpy files
(especially large numpy files can fail with BadZipFile Errors)
Args:
p: path to file to load
keys: keys to load from npz file
num_tries: number of tries to load file
*args: passed to `np.load`
**kwargs: passed to `np.load`
Returns:
        dict: loaded data (None if loading failed num_tries times)
"""
if num_tries <= 0:
        raise ValueError(f"num_tries needs to be larger than 0, found {num_tries} tries.")
    for i in range(num_tries):  # try reading the file num_tries times
try:
_data = np.load(str(p), *args, **kwargs)
data = {k: _data[k] for k in keys}
break
except Exception as e:
if i == num_tries - 1:
logger.error(f"Could not unpack {p}")
return None
time.sleep(5.)
return data
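# Round-trip sanity check for the save/load helpers above (suffixes are
# appended automatically; path is hypothetical):
#
# save_json({"a": 1}, "/tmp/example")        # writes /tmp/example.json
# assert load_json("/tmp/example") == {"a": 1}
# save_pickle([1, 2, 3], "/tmp/example")     # writes /tmp/example.pkl
# assert load_pickle("/tmp/example") == [1, 2, 3]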
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
Pathlike = Union[Path, str]
def subfiles(dir_path: Path, identifier: str, join: bool) -> List[str]:
"""
Get all paths
Args:
dir_path: path to directory
join: return dir_path+file_name instead of file_name
        identifier: glob pattern to select files
Returns:
List[str]: found paths/file names
"""
paths = list(map(str, list(Path(dir_path).glob(identifier))))
if not join:
paths = [p.rsplit(os.path.sep, 1)[-1] for p in paths]
return paths
def get_paths_raw_to_split(data_dir: Path, output_dir: Path,
subdirs: tuple = ("imagesTr", "imagesTs")) -> Tuple[
List[Path], List[Path]]:
"""
    Search subdirs for all *.nii.gz files which need to be split and
    create lists with source and target paths of all files
    (target paths retain subfolders inside the output dir)
Args:
        data_dir (Path): top directory where data is located
        output_dir (Path): output directory for split data
subdirs (Tuple[str]): subdirectories which should be searched for data
Returns:
List[Path]: path to all nii files in subfolders of source directory
List[Path]: path to respective target directory
"""
source_files, target_dirs = [], []
for subdir in subdirs:
sub_output_dir = output_dir / subdir
if not sub_output_dir.is_dir():
sub_output_dir.mkdir(parents=True)
sub_data_dir = data_dir / subdir
nii_files = list(sub_data_dir.glob('*.nii.gz'))
nii_files = list(filter(lambda x: not x.name.startswith('.'), nii_files))
nii_files.sort()
for n in nii_files:
source_files.append(n)
target_dirs.append(sub_output_dir)
return source_files, target_dirs
def get_paths_from_splitted_dir(
num_modalities: int,
splitted_4d_output_dir: Path,
test: bool = False,
labels: bool = True,
remove_ids: Optional[Sequence[str]] = None,
) -> List[List[Path]]:
"""
    Create list of all cases (data and label; label is at last position) inside the split data dir
Args:
num_modalities (int): number of modalities
        splitted_4d_output_dir (Path): path to dir where the split 4d data is located
test: get paths from test data (if False, searches for train data)
labels: add path to labels at last position of each case
remove_ids: case ids which should be removed from the list. If None,
no case ids are removed
Returns:
        List[List[Path]]: paths to all split files;
each case contains its data files and the label file is at the end
"""
data_subdir = "imagesTs" if test else "imagesTr"
labels_subdir = "labelsTs" if test else "labelsTr"
training_ids = get_case_ids_from_dir(
splitted_4d_output_dir / data_subdir,
remove_modality=True,
)
if remove_ids is not None:
training_ids = [t for t in training_ids if t not in remove_ids]
all_cases = []
for case_id in training_ids:
case_paths = []
for mod in range(num_modalities):
case_paths.append(
splitted_4d_output_dir / data_subdir / f"{case_id}_{mod:04d}.nii.gz")
if labels:
case_paths.append((splitted_4d_output_dir / labels_subdir) / f"{case_id}.nii.gz")
all_cases.append(case_paths)
return all_cases
def get_case_ids_from_dir(dir_path: Path, unique: bool = True,
remove_modality: bool = True, join: bool = False,
pattern="*.nii.gz") -> List[str]:
"""
Get all case ids from a single folder
Args:
dir_path: path to folder
unique: remove all duplicates
remove_modality: remove the modality string from the filename
join: append case ids to directory path
        pattern: glob pattern used to select files
Returns:
List[str]: all case ids inside the folder
"""
files = map(str, list(Path(dir_path).glob(pattern)))
case_ids = [get_case_id_from_path(f, remove_modality=remove_modality) for f in files]
if unique:
case_ids = list(set(case_ids))
if join:
case_ids = [os.path.join(dir_path, c) for c in case_ids]
return case_ids
def get_case_id_from_path(file_path: Pathlike, remove_modality: bool = True) -> str:
"""
    Get case id from path to file
Args:
file_path (str): path to file as string
remove_modality (bool): remove the modality string from the filename
(only used if file ends with .nii.gz)
Returns:
str: case id
"""
file_name = str(file_path).rsplit(os.path.sep, 1)[1]
return get_case_id_from_file(file_name, remove_modality=remove_modality)
def get_case_id_from_file(file_name: str, remove_modality: bool = True) -> str:
"""
    Cut off the file ending (e.g. ".nii.gz") from the file name
Args:
file_name (str): name of file with .nii.gz ending
remove_modality (bool): remove the modality string from the filename
Returns:
str: name of file without ending
"""
file_name = file_name.split('.')[0]
if remove_modality:
file_name = file_name[:-5]
return file_name
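# Example: "case_000_0000.nii.gz" -> split('.')[0] gives "case_000_0000";
# remove_modality strips the five character modality suffix ("_0000"),
# leaving "case_000".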
def get_task(task_id: str, name: bool = False, models: bool = False) -> Union[Path, str]:
"""
Resolve task name/dir
Args:
task_id: identifier of task.
E.g. task dir = ../Task12_LIDC
Possible task ids: Task12, LIDC, Task12_LIDC
name: only return the name of the task
models: uses model folder to look for names
Returns:
Union[Path, str]:
path to data task directory if name is False
name of task if name is True
"""
if models:
t = os.getenv("det_models")
else:
t = os.getenv("det_data")
if t is None:
raise ValueError("Framework not configured correctly! "
"Please set `det_data` and `det_models` as environment variables!")
det_data = Path(t)
all_tasks = [d.stem for d in det_data.iterdir() if d.is_dir() and "Task" in d.name]
    if task_id.startswith("Task"):
        task_id = task_id[4:]
    all_tasks = [tn[4:] for tn in all_tasks]
task_options_exact = [d for d in all_tasks if task_id in d]
task_number_id = [tn for tn in all_tasks if tn.split('_', 1)[0] == task_id]
task_name_id = [tn for tn in all_tasks if tn.split('_', 1)[1] == task_id]
if len(task_options_exact) == 1:
result = det_data / f"Task{task_options_exact[0]}"
elif len(task_number_id) == 1:
result = det_data / f"Task{task_number_id[0]}"
elif len(task_name_id) == 1:
result = det_data / f"Task{task_name_id[0]}"
else:
        raise ValueError(f"Did not find task id {task_id}. "
                         f"Options are: {all_tasks}")
if name:
result = result.stem
return result
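# Example (directory layout hypothetical): with det_data=/data and an existing
# folder /data/Task012_LIDC, get_task("Task012"), get_task("012") and
# get_task("LIDC") all resolve to Path("/data/Task012_LIDC"), while
# get_task("LIDC", name=True) returns "Task012_LIDC".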
def get_training_dir(model_dir: Pathlike, fold: int) -> Path:
"""
Find training dir from a specific model dir
Args:
model_dir: path to model dir e.g. ../Task12_LIDC/RetinaUNetV0
fold: fold to look for. if -1 look for consolidated dir
Returns:
Path: path to training dir
"""
model_dir = Path(model_dir)
identifier = f"fold{fold}" if fold != -1 else "consolidated"
candidates = [p for p in model_dir.iterdir() if p.is_dir() and identifier in p.stem]
if len(candidates) == 1:
return candidates[0]
else:
raise ValueError(f"Found wrong number of training dirs {candidates} in {model_dir}")
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import shutil
import numpy as np
import SimpleITK as sitk
from pathlib import Path
from typing import Dict, List, Sequence, Optional
from nndet.io.paths import Pathlike
from loguru import logger
from sklearn.model_selection import train_test_split
from nndet.io.paths import get_case_ids_from_dir
from nndet.io.load import save_json
from nndet.utils.clustering import seg2instances, remove_classes, reorder_classes
__all__ = ["maybe_split_4d_nifti", "instances_from_segmentation", "sitk_copy_metadata"]
def maybe_split_4d_nifti(source_file: Path, output_folder: Path):
"""
Process a single nifti file
if 3D File: copies file to target location
    if 4D File: splits into multiple 3D files and appends modality endings (_0000, _0001, ...) to indicate channels
Args:
source_file (Path): path to source file
output_folder (Path): path to target directory
    Raises:
TypeError: Data must be 3D or 4D
"""
img_itk = sitk.ReadImage(str(source_file))
dim = img_itk.GetDimension()
filename = source_file.name
if dim == 3:
# -7 cuts the .nii.gz part
shutil.copy(str(source_file), str(output_folder / (filename[:-7] + "_0000.nii.gz")))
return
elif dim == 4:
imgs_splitted = split_4d_itk(img_itk)
for idx, img in enumerate(imgs_splitted):
sitk.WriteImage(img, str(output_folder / (filename[:-7] + "_%04.0d.nii.gz" % idx)))
else:
raise TypeError(f"Unexpected dimensionality: {dim} of file {source_file}, cannot split")
def split_4d_itk(img_itk: sitk.Image) -> List[sitk.Image]:
"""
    Helper function to split a 4d itk image into multiple 3d images
Args:
img_itk: 4D input image
Returns:
List[sitk.Image]: 3d output images
"""
img_npy = sitk.GetArrayFromImage(img_itk)
spacing = img_itk.GetSpacing()
origin = img_itk.GetOrigin()
direction = np.array(img_itk.GetDirection()).reshape(4, 4)
spacing = tuple(list(spacing[:-1]))
assert len(spacing) == 3
origin = tuple(list(origin[:-1]))
assert len(origin) == 3
direction = tuple(direction[:-1, :-1].reshape(-1))
assert len(direction) == 9
images_new = []
    for t in range(img_npy.shape[0]):
img = img_npy[t]
images_new.append(
create_itk_image_spatial_props(img, spacing, origin, direction))
return images_new
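# Usage sketch (file name hypothetical): a 4D series with three timepoints
# yields three 3D images with the spatial metadata preserved per volume:
#
# imgs = split_4d_itk(sitk.ReadImage("series_4d.nii.gz"))
# len(imgs) == 3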
def create_itk_image_spatial_props(
data: np.ndarray, spacing: Sequence[float], origin: Sequence[float],
direction: Sequence[Sequence[float]]) -> sitk.Image:
"""
Create new sitk image and set spatial tags
Args:
        data: image data
        spacing: target spacing
        origin: target origin
        direction: target direction (flattened matrix)
Returns:
sitk.Image: new image
"""
data_itk = sitk.GetImageFromArray(data)
data_itk.SetSpacing(spacing)
data_itk.SetOrigin(origin)
data_itk.SetDirection(direction)
return data_itk
def sitk_copy_metadata(img_source: sitk.Image, img_target: sitk.Image) -> sitk.Image:
"""
Copy metadata (spacing, origin, direction) from source to target image
    Args:
img_source: source image
img_target: target image
Returns:
SimpleITK.Image: target image with copied metadata
"""
raise RuntimeError("Deprecated")
spacing = img_source.GetSpacing()
img_target.SetSpacing(spacing)
origin = img_source.GetOrigin()
img_target.SetOrigin(origin)
direction = img_source.GetDirection()
img_target.SetDirection(direction)
return img_target
def instances_from_segmentation(source_file: Path, output_folder: Path,
rm_classes: Sequence[int] = None,
ro_classes: Dict[int, int] = None,
subtract_one_of_classes: bool = True,
fg_vs_bg: bool = False,
file_name: Optional[str] = None
):
"""
    1. Optionally removes classes from the segmentation
       (e.g. organ segmentations which are not useful for detection)
    2. Optionally reorders the segmentation indices
    3. Converts the semantic segmentation to an instance segmentation via
       connected components
Args:
source_file: path to semantic segmentation file
output_folder: folder where processed file will be saved
rm_classes: classes to remove from semantic segmentation
ro_classes: reorder classes before instances are generated
subtract_one_of_classes: subtracts one from the classes
in the instance mapping (detection networks assume
that classes start from 0)
        fg_vs_bg: map all foreground classes to a single class to run a
            foreground vs background detection task
file_name: name of saved file (without file type!)
"""
if subtract_one_of_classes and fg_vs_bg:
logger.info("subtract_one_of_classes will be ignored because fg_vs_bg is "
"active and all foreground classes ill be mapped to 0")
seg_itk = sitk.ReadImage(str(source_file))
seg_npy = sitk.GetArrayFromImage(seg_itk)
if rm_classes is not None:
seg_npy = remove_classes(seg_npy, rm_classes)
if ro_classes is not None:
seg_npy = reorder_classes(seg_npy, ro_classes)
instances, instance_classes = seg2instances(seg_npy)
if fg_vs_bg:
num_instances_check = len(instance_classes)
seg_npy[seg_npy > 0] = 1
instances, instance_classes = seg2instances(seg_npy)
num_instances = len(instance_classes)
if num_instances != num_instances_check:
logger.warning(f"Lost instance: Found {num_instances} instances before "
f"fg_vs_bg but {num_instances_check} instances after it")
if subtract_one_of_classes:
for key in instance_classes.keys():
instance_classes[key] -= 1
if fg_vs_bg:
for key in instance_classes.keys():
instance_classes[key] = 0
seg_itk_new = sitk.GetImageFromArray(instances)
seg_itk_new = sitk_copy_metadata(seg_itk, seg_itk_new)
if file_name is None:
suffix_length = sum(map(len, source_file.suffixes))
file_name = source_file.name[:-suffix_length]
save_json({"instances": instance_classes}, output_folder / f"{file_name}.json")
sitk.WriteImage(seg_itk_new, str(output_folder / f"{file_name}.nii.gz"))
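# Usage sketch (illustrative; paths and class ids are hypothetical):
# >>> instances_from_segmentation(
# ...     Path("raw/case_000.nii.gz"), Path("labelsTr"),
# ...     rm_classes=(3,),               # drop a class not needed for detection
# ...     subtract_one_of_classes=True,  # detection classes start at 0
# ... )
# writes labelsTr/case_000.nii.gz (instance ids) and labelsTr/case_000.json
# with a mapping like {"instances": {"1": 0, "2": 1}}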
def create_test_split(splitted_dir: Pathlike,
num_modalities: int,
test_size: float = 0.3,
random_state: int = 0,
shuffle: bool = True,
):
"""
Helper function to create an artificial test split from the splitted data
Args:
splitted_dir: path to directory with splitted data. `imagesTr` and
`labelsTr` need to exist beforehand. `imagesTs` and `labelsTs`
will be created automatically.
num_modalities: number of modalities
test_size: size of test set, needs to be a value between 0 and 1
        random_state: seed for splitting
shuffle: shuffle data
"""
images_tr = Path(splitted_dir) / "imagesTr"
labels_tr = Path(splitted_dir) / "labelsTr"
images_ts = Path(splitted_dir) / "imagesTs"
labels_ts = Path(splitted_dir) / "labelsTs"
if not images_tr.is_dir():
raise ValueError(f"No dir with training images found {images_tr}")
if not labels_tr.is_dir():
raise ValueError(f"No dir with training labels found {labels_tr}")
images_ts.mkdir(parents=True, exist_ok=True)
labels_ts.mkdir(parents=True, exist_ok=True)
case_ids = sorted(get_case_ids_from_dir(images_tr, remove_modality=True))
logger.info(f"Found {len(case_ids)} to split")
train_ids, test_ids = train_test_split(
case_ids, test_size=test_size, random_state=random_state, shuffle=shuffle)
logger.info(f"Using {train_ids} for training and {test_ids} for testing.")
for cid in test_ids:
for modality in range(num_modalities):
shutil.move(images_tr / f"{cid}_{modality:04d}.nii.gz",
images_ts / f"{cid}_{modality:04d}.nii.gz")
shutil.move(labels_tr / f"{cid}.nii.gz", labels_ts / f"{cid}.nii.gz")
if (labels_tr / f"{cid}.json").is_file():
shutil.move(labels_tr / f"{cid}.json", labels_ts / f"{cid}.json")
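# Usage sketch (illustrative; the directory name is hypothetical):
# >>> create_test_split("Task12_LIDC/raw_splitted", num_modalities=1,
# ...                   test_size=0.3, random_state=0, shuffle=True)
# moves ~30% of the cases (images, labels and box jsons) from
# imagesTr/labelsTr into imagesTs/labelsTs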
from nndet.io.transforms.base import AbstractTransform
from nndet.io.transforms.instances import (
Instances2Boxes,
Instances2Segmentation,
FindInstances,
)
from nndet.io.transforms.utils import (
AddProps2Data,
NoOp,
FilterKeys,
)
from nndet.io.transforms.spatial import (
Mirror,
)
from typing import Any
import torch
class AbstractTransform(torch.nn.Module):
def __init__(self, grad: bool = False, **kwargs):
"""
Args:
grad: enable gradient computation inside transformation
"""
super().__init__()
self.grad = grad
def __call__(self, *args, **kwargs) -> Any:
"""
Call super class with correct torch context
Args:
*args: forwarded positional arguments
**kwargs: forwarded keyword arguments
Returns:
Any: transformed data
"""
if self.grad:
context = torch.enable_grad()
else:
context = torch.no_grad()
with context:
return super().__call__(*args, **kwargs)
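# Minimal subclass sketch (illustrative, not part of the library):
# class Clamp(AbstractTransform):
#     def __init__(self, key: str, min_val: float, max_val: float):
#         super().__init__(grad=False)
#         self.key, self.min_val, self.max_val = key, min_val, max_val
#
#     def forward(self, **data) -> dict:
#         # runs inside torch.no_grad() because grad=False was passed above
#         data[self.key] = data[self.key].clamp(self.min_val, self.max_val)
#         return data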
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import numpy as np
from torch import Tensor
from typing import Dict, Union, Sequence, Tuple, Optional
from nndet.io.transforms.base import AbstractTransform
class FindInstances(AbstractTransform):
def __init__(self, instance_key: str, save_key: str = "present_instances", **kwargs):
super().__init__(grad=False)
self.instance_key = instance_key
self.save_key = save_key
def forward(self, **data) -> dict:
present_instances = []
for instance_element in data[self.instance_key].split(1):
tmp = instance_element.to(dtype=torch.int).unique(sorted=True)
tmp = tmp[tmp > 0]
present_instances.append(tmp)
data[self.save_key] = present_instances
return data
class Instances2Boxes(AbstractTransform):
def __init__(self, instance_key: str, map_key: str,
box_key: str, class_key: str, grad: bool = False,
present_instances: Optional[str] = None,
**kwargs):
"""
Convert instance segmentation to bounding boxes
        Args:
            instance_key: key where instance segmentation is located
            map_key: key where mapping from instances to classes is located
                (should be a dict mapping keys (instances) to items (classes))
box_key: key where boxes should be saved
class_key: key where classes of instances will be saved
grad: enable gradient computation inside transformation
            present_instances: key where precomputed present instances are
                saved. If None, the present instances are computed on the fly
"""
super().__init__(grad=grad, **kwargs)
self.class_key = class_key
self.box_key = box_key
self.map_key = map_key
self.instance_key = instance_key
self.present_instances = present_instances
def forward(self, **data) -> dict:
"""
Extract boxes from instances
Args:
**data: batch dict
Returns:
dict: processed batch
"""
data[self.box_key] = []
data[self.class_key] = []
for batch_idx, instance_element in enumerate(data[self.instance_key].split(1)):
_present_instances = data[self.present_instances][batch_idx] if self.present_instances is not None else None
_boxes, instance_idx = instances_to_boxes(
instance_element, instance_element.ndim - 2, instances=_present_instances)
_classes = get_instance_class_from_properties(
instance_idx, data[self.map_key][batch_idx])
_classes = _classes.to(device=_boxes.device)
data[self.box_key].append(_boxes)
data[self.class_key].append(_classes)
return data
def instances_to_boxes(seg: Tensor,
dim: int = None,
instances: Optional[Sequence[int]] = None,
) -> Tuple[Tensor, Tensor]:
"""
Convert instance segmentation to bounding boxes (not batched)
    Args:
        seg: instance segmentation of individual classes [..., dims]
        dim: number of spatial dimensions to create bounding box for
            (always start from the last dimension). If None, all dimensions
            are used
        instances: precomputed instance ids present in seg. If None, they
            are computed from seg
    Returns:
        Tensor: bounding boxes (x1, y1, x2, y2(, z1, z2)) [N, dim * 2]
        Tensor: instance ids corresponding to the boxes
"""
if dim is None:
dim = seg.ndim
boxes = []
_seg = seg.detach()
if instances is None:
instances = _seg.unique(sorted=True)
instances = instances[instances > 0]
for _idx in instances:
instance_idx = (_seg == _idx).nonzero(as_tuple=False)
_mins = instance_idx[:, -3:].min(dim=0)[0]
_maxs = instance_idx[:, -3:].max(dim=0)[0]
box = [_mins[-dim] - 1, _mins[(-dim) + 1] - 1, _maxs[-dim] + 1, _maxs[(-dim) + 1] + 1]
if dim > 2:
box = box + [_mins[(-dim) + 2] - 1, _maxs[(-dim) + 2] + 1]
boxes.append(torch.tensor(box))
if boxes:
boxes = torch.stack(boxes)
else:
        boxes = torch.empty(0, dim * 2)
return boxes.to(dtype=torch.float, device=seg.device), instances
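# Worked example (illustrative): a single 2D instance spanning rows 2-4 and
# columns 5-7 of a [1, 10, 10] segmentation
# >>> seg = torch.zeros(1, 10, 10, dtype=torch.long)
# >>> seg[0, 2:5, 5:8] = 1
# >>> instances_to_boxes(seg, dim=2)
# (tensor([[1., 4., 5., 8.]]), tensor([1]))  # box enlarged by 1 voxel per side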
def instances_to_boxes_np(
seg: np.ndarray,
dim: int = None,
instances: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Convert instance segmentation to bounding boxes (not batched)
    Args:
        seg: instance segmentation of individual classes [..., dims]
        dim: number of spatial dimensions to create bounding box for
            (always start from the last dimension). If None, all dimensions
            are used
        instances: precomputed instance ids present in seg. If None, they
            are computed from seg
    Returns:
        np.ndarray: bounding boxes (x1, y1, x2, y2(, z1, z2)) [N, dim * 2]
        np.ndarray: instance ids corresponding to the boxes
"""
if dim is None:
dim = seg.ndim
boxes = []
if instances is None:
instances = np.unique(seg)
instances = instances[instances > 0]
for _idx in instances:
instance_idx = np.stack(np.nonzero(seg == _idx), axis=1)
_mins = np.min(instance_idx[:, -dim:], axis=0)
_maxs = np.max(instance_idx[:, -dim:], axis=0)
box = [_mins[-dim] - 1, _mins[(-dim) + 1] - 1, _maxs[-dim] + 1, _maxs[(-dim) + 1] + 1]
if dim > 2:
box = box + [_mins[(-dim) + 2] - 1, _maxs[(-dim) + 2] + 1]
boxes.append(np.array(box))
if boxes:
boxes = np.stack(boxes)
else:
        boxes = np.zeros((0, dim * 2))
return boxes, instances
def get_instance_class_from_properties(
instance_idx: torch.Tensor, map_dict: Dict[str, Union[str, int]]) -> Tensor:
"""
Extract instance classes form mapping dict
Args:
        instance_idx: instance ids present in segmentation
map_dict: dict mapping instance ids (keys) to classes
Returns:
Tensor: extracted instance classes
"""
instance_idx, _ = instance_idx.sort()
classes = [int(map_dict[str(int(idx.detach().item()))]) for idx in instance_idx]
return torch.tensor(classes, device=instance_idx.device)
def get_instance_class_from_properties_seq(
instance_idx: Sequence, map_dict: Dict[str, Union[str, int]]) -> Sequence:
"""
Extract instance classes form mapping dict
Args:
        instance_idx: instance ids present in segmentation
map_dict: dict mapping instance ids (keys) to classes
Returns:
Sequence[int]: extracted instance classes
"""
instance_idx = sorted(instance_idx)
classes = [int(map_dict[str(int(idx))]) for idx in instance_idx]
return classes
class Instances2Segmentation(AbstractTransform):
def __init__(self, instance_key: str, map_key: str, seg_key: str = None,
add_background: bool = True, grad: bool = False,
present_instances: Optional[str] = None,
):
"""
Convert instances to semantic segmentation
Args:
instance_key: key where instance segmentation is located
map_key: key where mapping from instances to classes is located
seg_key: key where segmentation should be saved; If None, the
instance key will be overwritten
add_background: adds +1 to classes from mapping for background
grad: enable gradient propagation through transformation
            present_instances: key where precomputed present instances are
                saved. If None, the present instances are computed on the fly
"""
super().__init__(grad=grad)
self.add_background = add_background
self.seg_key = seg_key if seg_key is not None else instance_key
self.map_key = map_key
self.instance_key = instance_key
self.present_instances = present_instances
def forward(self, **data) -> dict:
"""
Convert instance segmentation to semantic segmentation
Args:
**data: batch dict
Returns:
dict: processed batch
"""
semantic = torch.zeros_like(data[self.instance_key])
        _present_instances = data[self.present_instances] if self.present_instances is not None else None
        for batch_idx in range(semantic.shape[0]):
            instance_idx = _present_instances[batch_idx] if _present_instances is not None else None
            instances_to_segmentation(data[self.instance_key][batch_idx],
                                      data[self.map_key][batch_idx],
                                      add_background=self.add_background,
                                      instance_idx=instance_idx,
                                      out=semantic[batch_idx])
data[self.seg_key] = semantic
return data
def instances_to_segmentation(instances: Tensor,
mapping: Dict[str, Union[str, int]],
add_background: bool = True,
instance_idx: Optional[Sequence[int]] = None,
out: Tensor = None) -> Tensor:
"""
Convert instances to semantic segmentation
Args:
instances: instance segmentation; foreground classes > 0; [dims]
mapping: mapping from each instance to class
        add_background: adds +1 to classes from mapping for background.
            Should be enabled if classes in mapping start from zero and
            disabled otherwise
        out: optional output tensor where results are saved
        instance_idx: precomputed instance ids present in sample. If None,
            the instance ids will be computed
Returns:
Tensor: semantic segmentation
"""
mapping = {int(key): int(item) for key, item in mapping.items()}
if out is None:
out = torch.zeros_like(instances)
if instance_idx is None:
instance_idx = instances.unique(sorted=True)
instance_idx = instance_idx[instance_idx > 0]
for instance_id in instance_idx:
_cls = mapping[instance_id.item()]
if add_background:
_cls += 1
out[instances == instance_id] = _cls
return out
def instances_to_segmentation_np(instances: np.ndarray,
mapping: Dict[Union[str, int], Union[str, int]],
add_background: bool = True,
out: np.ndarray = None) -> np.ndarray:
"""
Convert instances to semantic segmentation
Args:
instances: instance segmentation; foreground classes > 0; [dims]
mapping: mapping from each instance to class
        add_background: adds +1 to classes from mapping for background.
            Should be enabled if classes in mapping start from zero and
            disabled otherwise
out: optional output tensor where results are saved
    Returns:
        np.ndarray: semantic segmentation
"""
mapping = {int(key): int(item) for key, item in mapping.items()}
if out is None:
out = np.zeros_like(instances)
instance_idx = np.unique(instances)
instance_idx = instance_idx[instance_idx > 0]
for instance_id in instance_idx:
_cls = mapping[instance_id]
if add_background:
_cls += 1
out[instances == instance_id] = _cls
return out
def get_bbox_np(seg: np.ndarray,
map_dict: Optional[Dict[Union[str, int], Union[str, int]]] = None,
**kwargs,
) -> dict:
"""
Get bounding boxes and mapping from instances to classes
Args:
seg: instance segmentation [1, dims]
        map_dict: mapping from instance ids to classes. If None, no classes
            are added to the result
Returns:
dict: extracted boxes and classes
`boxes` (np.ndarray): bounding boxes [N, dims * 2]
`classes` (np.ndarray): classes (in same order as boxes) [N]
"""
if map_dict is not None:
map_dict = {str(key): str(item) for key, item in map_dict.items()}
result = {}
boxes, instance_idx = instances_to_boxes_np(seg[0], **kwargs)
result["boxes"] = boxes
if map_dict is not None:
box_classes = get_instance_class_from_properties_seq(instance_idx, map_dict)
result["classes"] = np.array(box_classes)
return result
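# Usage sketch (illustrative): a toy segmentation with a single instance,
# mapped from instance id 1 to class 0
# >>> seg = np.zeros((1, 10, 10), dtype=np.int32)
# >>> seg[0, 2:5, 5:8] = 1
# >>> get_bbox_np(seg, map_dict={1: 0}, dim=2)
# {'boxes': array([[1, 4, 5, 8]]), 'classes': array([0])}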
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch import Tensor
from typing import Sequence, List
from nndet.io.transforms.base import AbstractTransform
class Mirror(AbstractTransform):
def __init__(self, keys: Sequence[str], dims: Sequence[int],
point_keys: Sequence[str] = (), box_keys: Sequence[str] = (),
grad: bool = False):
"""
Mirror Transform
Args:
keys: keys to mirror (first key must correspond to data for
shape information) expected shape [N, C, dims]
dims: dimensions to mirror (starting from the first spatial
dimension)
point_keys: keys where points for transformation are located
[N, dims]
box_keys: keys where boxes are located; following format
needs to be used (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
grad: enable gradient computation inside transformation
"""
super().__init__(grad=grad)
self.dims = dims
self.keys = keys
self.point_keys = point_keys
self.box_keys = box_keys
def forward(self, **data) -> dict:
"""
        Mirror data, boxes, and points along the configured dimensions
        Args:
            **data: batch dict
        Returns:
            dict: dict with transformed data
"""
for key in self.keys:
data[key] = mirror(data[key], self.dims)
data_shape = data[self.keys[0]].shape
data_shapes = [tuple(data_shape[2:])] * data_shape[0]
for key in self.box_keys:
points = [boxes2points(b) for b in data[key]]
points = mirror_points(points, self.dims, data_shapes)
data[key] = [points2boxes(p) for p in points]
for key in self.point_keys:
data[key] = mirror_points(data[key], self.dims, data_shapes)
return data
def invert(self, **data) -> dict:
"""
Revert mirroring
Args:
**data: dict with data
Returns:
dict with re-transformed data
"""
return self(**data)
def mirror(data: torch.Tensor, dims: Sequence[int]) -> torch.Tensor:
"""
Mirror data at dims
    Args:
        data: input data [N, C, spatial dims]
        dims: dimensions to mirror, starting from the spatial dims,
            e.g. dims=(0,) mirrors the first spatial dimension
    Returns:
torch.Tensor: tensor with mirrored dimensions
"""
dims = [d + 2 for d in dims]
return data.flip(dims)
def mirror_points(points: Sequence[torch.Tensor], dims: Sequence[int],
data_shapes: Sequence[Sequence[int]]) -> List[torch.Tensor]:
"""
Mirror points along given dimensions
Args:
points: points per batch element [N, dims]
dims: dimensions to mirror
data_shapes: shape of data
    Returns:
        List[Tensor]: transformed points [N, dims]
"""
cartesian_dims = points[0].shape[1]
homogeneous_points = points_to_homogeneous(points)
transformed = []
for points_per_image, data_shape in zip(homogeneous_points, data_shapes):
matrix = nd_mirror_matrix(cartesian_dims, dims, data_shape).to(points_per_image)
transformed.append(points_per_image @ matrix.transpose(0, 1))
return points_to_cartesian(transformed)
def nd_mirror_matrix(cartesian_dims: int, mirror_dims: Sequence[int],
data_shape: Sequence[int]) -> torch.Tensor:
"""
    Create an n-dimensional matrix for mirroring
Args:
cartesian_dims: number of cartesian dimensions
mirror_dims: dimensions to mirror
data_shape: shape of image
Returns:
        Tensor: matrix for mirroring in homogeneous coordinates,
            [cartesian_dims + 1, cartesian_dims + 1]
"""
mirror_dims = tuple(mirror_dims)
data_shape = list(data_shape)
homogeneous_dims = cartesian_dims + 1
mat = torch.eye(homogeneous_dims, dtype=torch.float)
# reflection
mat[[mirror_dims] * 2] = -1
    # add the data shape as offset to the axes which were reflected
self_tensor = torch.zeros(cartesian_dims, dtype=torch.float)
    index_tensor = torch.tensor(mirror_dims, dtype=torch.long)
src_tensor = torch.tensor([1] * len(mirror_dims), dtype=torch.float)
offset_mask = self_tensor.scatter_(0, index_tensor, src_tensor)
mat[:-1, -1] = offset_mask * torch.tensor(data_shape)
return mat
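# Worked example (illustrative): mirroring the first spatial dimension of an
# image with shape (4, 6) in homogeneous coordinates
# >>> nd_mirror_matrix(cartesian_dims=2, mirror_dims=(0,), data_shape=(4, 6))
# tensor([[-1., 0., 4.],
#         [ 0., 1., 0.],
#         [ 0., 0., 1.]])  # i.e. x' = 4 - x, y' = y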
def points_to_homogeneous(points: Sequence[torch.Tensor]) -> List[torch.Tensor]:
"""
Transforms points from cartesian to homogeneous coordinates
Args:
points: list of points to transform [N, dims] where N is the number
of points and dims is the number of spatial dimensions
    Returns:
        List[Tensor]: the points in homogeneous coordinates [N, dims + 1]
"""
return [torch.cat([p, torch.ones(p.shape[0], 1).to(p)], dim=1) for p in points]
def points_to_cartesian(points: Sequence[torch.Tensor]) -> List[torch.Tensor]:
"""
Transforms points in homogeneous coordinates back to cartesian
coordinates.
Args:
points: homogeneous points [N, in_dims], N number of points,
in_dims number of input dimensions (spatial dimensions + 1)
Returns:
        List[Tensor]: cartesian points [N, in_dims - 1] = [N, dims]
"""
return [p[..., :-1] / p[..., -1][:, None] for p in points]
def boxes2points(boxes: Tensor) -> Tensor:
"""
Convert boxes to points
Args:
        boxes: (x1, y1, x2, y2(, z1, z2)) [N, dims * 2]
Returns:
Tensor: points [N * 2, dims]
"""
if boxes.shape[1] == 4:
idx0 = [0, 1]
idx1 = [2, 3]
else:
idx0 = [0, 1, 4]
idx1 = [2, 3, 5]
points0 = boxes[:, idx0]
points1 = boxes[:, idx1]
return torch.cat([points0, points1], dim=0)
def points2boxes(points: Tensor) -> Tensor:
"""
Convert points to boxes
Args:
        points: points ordered as produced by :func:`boxes2points`, i.e. the
            first corners of all boxes followed by the second corners;
            format of a point: (x, y(, z)) [N, dims]
Returns:
Tensor: bounding boxes [N / 2, dims * 2]
"""
if points.nelement() > 0:
points0, points1 = points.split(points.shape[0] // 2)
boxes = torch.zeros(points.shape[0] // 2, points.shape[1] * 2).to(
device=points.device, dtype=points.dtype)
boxes[:, 0] = torch.min(points0[:, 0], points1[:, 0])
boxes[:, 1] = torch.min(points0[:, 1], points1[:, 1])
boxes[:, 2] = torch.max(points0[:, 0], points1[:, 0])
boxes[:, 3] = torch.max(points0[:, 1], points1[:, 1])
if boxes.shape[1] == 6:
boxes[:, 4] = torch.min(points0[:, 2], points1[:, 2])
boxes[:, 5] = torch.max(points0[:, 2], points1[:, 2])
return boxes
else:
return torch.tensor([]).view(-1, points.shape[1] * 2).to(points)
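# Round-trip sketch (illustrative): Mirror transforms boxes by converting them
# to plain point sets and back
# >>> boxes = torch.tensor([[1., 4., 5., 8.]])
# >>> pts = boxes2points(boxes)  # tensor([[1., 4.], [5., 8.]])
# >>> points2boxes(pts)          # tensor([[1., 4., 5., 8.]])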
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Hashable, Mapping, Sequence
from nndet.io.transforms.base import AbstractTransform
class AddProps2Data(AbstractTransform):
def __init__(self, props_key: str, key_mapping: Mapping[str, str], **kwargs):
"""
Move properties from property dict to data dict
        Args:
            props_key: key where the properties are located
            key_mapping: maps property keys to new keys in the data dict
"""
super().__init__(grad=False, **kwargs)
self.key_mapping = key_mapping
self.props_key = props_key
def forward(self, **data) -> dict:
"""
Move keys from properties to data
Args:
**data: batch dict
Returns:
dict: updated batch
"""
props = data[self.props_key]
for source, target in self.key_mapping.items():
data[target] = [p[source] for p in props]
return data
class NoOp(AbstractTransform):
def __init__(self, grad: bool = False):
"""
Forward input without change
Args:
grad: propagate gradient through transformation
"""
super().__init__(grad=grad)
def forward(self, **data) -> dict:
"""
NoOp
"""
return data
def invert(self, **data) -> dict:
"""
NoOp
"""
return data
class FilterKeys(AbstractTransform):
def __init__(self, keys: Sequence[Hashable]):
super().__init__(grad=False)
self.keys = keys
def forward(self, **data) -> dict:
return {k: data[k] for k in self.keys}
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List
from loguru import logger
from collections import OrderedDict
from pathlib import Path
from nndet.io.load import load_pickle
from nndet.io.paths import get_case_ids_from_dir, get_case_id_from_path, Pathlike
def get_np_paths_from_dir(directory: Pathlike) -> List[str]:
"""
    First looks for npy files inside the directory. If no files are found,
    it looks for npz files.
Args:
directory: path to folder
Raises:
RuntimeError: raised if no npy and no npz files are found
Returns:
List[str]: paths to files
"""
case_paths = get_case_ids_from_dir(
Path(directory), remove_modality=False, join=True, pattern="*.npy")
if not case_paths:
logger.info(f"Did not find any npy files, looking for npz files. Folder: {directory}")
case_paths = get_case_ids_from_dir(
Path(directory), remove_modality=False, join=True, pattern="*.npz")
if not case_paths:
logger.error(f"Did not find any npz files.")
raise RuntimeError(f"Did not find any npz files. Folder: {directory}")
case_paths = [f for f in case_paths if "_seg" not in f]
case_paths.sort()
return case_paths
def load_dataset(folder: Pathlike) -> dict:
"""
Load dataset (path and properties, NOT the actual data) and
save them into dict by their path
Args:
folder: folder to look for data
Raises:
RuntimeError: data needs to be provided in npy or npz format
Returns:
dict: loaded data
"""
folder = Path(folder)
case_identifiers = get_np_paths_from_dir(folder)
dataset = OrderedDict()
for c in case_identifiers:
dataset[c] = OrderedDict()
dataset[c]['data_file'] = str(folder / f"{c}.npy")
dataset[c]['seg_file'] = str(folder / f"{c}_seg.npy")
dataset[c]['properties_file'] = str(folder / f"{c}.pkl")
dataset[c]['boxes_file'] = str(folder / f"{c}_boxes.pkl")
return dataset
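# Sketch of the resulting structure (illustrative; folder and case identifier
# are hypothetical):
# >>> load_dataset("preprocessed_dir")
# {'case_000': {'data_file': '.../case_000.npy',
#               'seg_file': '.../case_000_seg.npy',
#               'properties_file': '.../case_000.pkl',
#               'boxes_file': '.../case_000_boxes.pkl'}, ...}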
def load_dataset_id(folder: Pathlike) -> dict:
"""
Load dataset (path and properties, NOT the actual data) and
save them into dict by their identifier
Args:
folder: folder to look for data
Raises:
RuntimeError: data needs to be provided in npy or npz format
Returns:
dict: loaded data
"""
folder = Path(folder)
case_paths = get_np_paths_from_dir(folder)
case_ids = [get_case_id_from_path(c, remove_modality=False) for c in case_paths]
dataset = OrderedDict()
for c in case_ids:
dataset[c] = OrderedDict()
dataset[c]['data_file'] = str(folder / f"{c}.npy")
dataset[c]['seg_file'] = str(folder / f"{c}_seg.npy")
dataset[c]['properties_file'] = str(folder / f"{c}.pkl")
dataset[c]['boxes_file'] = str(folder / f"{c}_boxes.pkl")
return dataset