Commit 7246044d authored by mibaumgartner

Merge remote-tracking branch 'origin/master' into main

parents fcec502f 6f4c3333
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import gc
import subprocess as sp
import math
import copy
from abc import ABC, abstractmethod
from functools import partial, reduce
from typing import Sequence, Union, Callable, Tuple
from contextlib import contextmanager
from loguru import logger
from nndet.arch.abstract import AbstractModel
"""
This is just a first prototype to estimate VRAM consumption for different GPUs.
I hope to update this soon.
"""
def b2mb(x): return x / (2**20)
def mb2b(x): return x * (2**20)
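# Hypothetical example of the two helpers: mb2b(11) == 11 * 2**20 == 11534336
# bytes, and b2mb(11534336) == 11.0 mb, i.e. they are exact inverses.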
# remove 11mb from target memory to have a little wiggle room
# (sometimes that amount was blocked on my GPU even though nothing was running)
ARCHS = {
"RTX2080TI": 11523260416 - int(mb2b(11))
}
# this is just an estimation ... probably depends on the CUDA version too
CUDA_CONTEXT = {
"none": 0,
"RTX2080TI": int(mb2b(910))
}
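# Usage note (sketch, not part of the original code): MemoryEstimatorDetection
# resolves both tables by name, e.g. target_mem="RTX2080TI" and
# context="RTX2080TI"; passing plain byte counts instead of strings is also
# supported (see the isinstance checks in MemoryEstimatorDetection.__init__).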
class MemoryEstimator(ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.batch_size = None
@abstractmethod
def estimate(self, *args, **kwargs):
raise NotImplementedError
class MemoryEstimatorDetection(MemoryEstimator):
def __init__(self,
target_mem: Union[float, str] = "RTX2080TI",
gpu_id: int = 0,
context: Union[float, str] = "RTX2080TI",
offset: int = mb2b(768),
batch_size: int = 1,
mixed_precision: bool = True):
"""
Estimate memory needed for training a specific network
Args:
target_mem: memory of target card (can be higher than
currently used card). Defaults to "RTX2080TI".
gpu_id: GPU id to use for estimation. Defaults to 0.
context: Memory which is reserved for cuda context. Depends on
CUDA version and GPU. Defaults to "RTX2080TI".
offset: Additional safety offset because memory consumption
can fluctuate a bit during training. Defaults to 768 mb.
batch_size: batch size to use for estimation. Defaults to 1.
mixed_precision: estimate for mixed precision training.
Defaults to True.
"""
super().__init__()
if isinstance(context, str):
self.context = CUDA_CONTEXT[context]
else:
self.context = context
self.offset = offset
self.block_mem_tensor = None
if isinstance(target_mem, str):
self.target_mem = ARCHS[target_mem]
else:
self.target_mem = target_mem
self.gpu_id = gpu_id
self.batch_size = batch_size
self.mixed_precision = mixed_precision
def create_offset_tensor_on_GPU(self) -> torch.Tensor:
device = f"cuda:{self.gpu_id}"
tensor_mem = torch.rand(1, dtype=float, requires_grad=False, device=device).element_size()
return torch.rand(math.ceil(self.offset / tensor_mem), dtype=float,
requires_grad=False, device=device)
def estimate(self,
min_shape: Sequence[int],
target_shape: Sequence[int],
network: AbstractModel,
optimizer_cls: Callable = torch.optim.Adam,
in_channels: int = None,
num_instances: int = 1,
) -> Tuple[int, bool]:
if in_channels is not None:
min_shape = [in_channels, *min_shape]
target_shape = [in_channels, *target_shape]
# all_mem - reserved_mem[misc + context] + context
available_mem = torch.cuda.get_device_properties(self.gpu_id).total_memory - \
smi_memory_allocated(self.gpu_id) + self.context
logger.info(
f"Found available gpu memory: {available_mem} bytes / {b2mb(available_mem)} mb "
f"and estimating for {self.target_mem} bytes / {b2mb(self.target_mem)}")
# if available_mem >= self.target_mem:
res = self._estimate_mem_available(
min_shape=min_shape,
target_shape=target_shape,
network=copy.deepcopy(network),
optimizer_cls=optimizer_cls,
num_instances=num_instances,
)
# else:
# res = self._estimate_mem_not_available(
# min_shape=min_shape, target_shape=target_shape,
# network=network, optimizer_cls=optimizer_cls,
# num_instances=num_instances,
# )
del self.block_mem_tensor
self.block_mem_tensor = None
torch.cuda.empty_cache()
gc.collect()
return res
def _estimate_mem_available(self,
min_shape: Sequence[int],
target_shape: Sequence[int],
network: AbstractModel,
optimizer_cls: Callable = torch.optim.Adam,
num_instances: int = 1,
) -> Tuple[int, bool]:
logger.info("Estimating in memory.")
fixed, dynamic = self.measure(shape=target_shape,
network=network,
optimizer_cls=optimizer_cls,
num_instances=num_instances,
)
estimated_mem = fixed + dynamic
return estimated_mem, estimated_mem < self.target_mem
def _estimate_mem_not_available(self,
min_shape: Sequence[int],
target_shape: Sequence[int],
network: AbstractModel,
optimizer_cls: Callable = torch.optim.Adam,
num_instances: int = 1,
) -> Tuple[int, bool]:
raise NotImplementedError("!!!!!This needs more refinement!!!!")
logger.info("Extrapolating memory consumption.")
assert all([t >= m for t, m in zip(target_shape, min_shape)])
fixed_mem, dyn_mem = self.measure(shape=min_shape,
network=network,
optimizer_cls=optimizer_cls,
num_instances=num_instances,
)
ratios = [t / m for t, m in zip(target_shape, min_shape)]
scale = reduce((lambda x, y: x * y), ratios)
estimated_dyn_mem = dyn_mem * scale
estimated_mem = estimated_dyn_mem + fixed_mem
if self.context is not None:
estimated_mem += self.context
return estimated_mem, estimated_mem < self.target_mem
def measure(self,
shape: Sequence[int],
network: AbstractModel,
optimizer_cls: Callable = torch.optim.Adam,
num_instances: int = 1,
):
device = torch.device("cuda", self.gpu_id)
logger.info(f"Estimating on {device} with shape {shape} and "
f"batch size {self.batch_size} and num_instances {num_instances}")
try:
loss = None
opt = None
inp = None
block_tensor = None
with cudnn_deterministic():
torch.cuda.reset_peak_memory_stats()
network = network.to(device)
# torch.cuda.memory_allocated
empty_mem = torch.cuda.memory_reserved()
scaler = torch.cuda.amp.GradScaler()
opt = optimizer_cls(network.parameters())
boxes = [[0, 0, 2, 2]]
if len(shape) == 4: # in_channels + spatial dims
boxes[0].extend((0, 2))
block_tensor = self.create_offset_tensor_on_GPU().to(device=device)
import time
time.sleep(1)
for _ in range(10):
opt.zero_grad()
inp = {"images": torch.rand((self.batch_size, *shape), device=device, dtype=torch.float),
"targets": {
"target_boxes": [torch.tensor(
boxes, device=device, dtype=torch.float).repeat(num_instances, 1)
for _ in range(self.batch_size)],
"target_classes": [torch.tensor(
[0] * num_instances, device=device, dtype=torch.float)
for _ in range(self.batch_size)],
"target_seg": torch.zeros(
(self.batch_size, *shape[1:]), device=device, dtype=torch.float),
}}
fixed_mem = torch.cuda.memory_reserved()
with torch.cuda.amp.autocast():
loss_dict, _ = network.train_step(
images=inp["images"],
targets=inp["targets"],
evaluation=False,
batch_num=0,
)
loss = sum(loss_dict.values())
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
dyn_mem = torch.cuda.memory_reserved()
except (RuntimeError,) as e:
logger.info(f"Caught error (If out of memory error do not worry): {e}")
empty_mem = 0
fixed_mem = float('Inf')
dyn_mem = float('Inf')
finally:
del loss
del opt
del inp
del block_tensor
network.cpu()
torch.cuda.empty_cache()
gc.collect()
logger.info(f"Measured: {b2mb(empty_mem)} mb empty, "
f"{b2mb(fixed_mem)} mb fixed, "
f"{b2mb(dyn_mem)} mb dynamic")
return fixed_mem - empty_mem, dyn_mem - fixed_mem
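# Hypothetical usage sketch (the network object is a placeholder, not part of
# this module): estimate() returns the measured byte count and whether it fits
# on the target card, e.g.
#   estimator = MemoryEstimatorDetection(target_mem="RTX2080TI", gpu_id=0)
#   mem, fits = estimator.estimate(
#       min_shape=[96, 96, 96], target_shape=[128, 128, 128],
#       network=my_network, in_channels=1, num_instances=2,
#   )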
def num_gpus():
"""
Number of GPUs independent of visible devices
"""
return str(sp.check_output(["nvidia-smi", "-L"])).count('UUID')
def smi_memory_allocated(gpu_id: int = 0) -> int:
"""
Read memory consumption from nvidia-smi
Args:
gpu_id: id of the GPU to query. Defaults to 0.
Returns:
int: measured GPU memory in bytes
"""
reading = int(sp.check_output(
['nvidia-smi', '--query-gpu=memory.used',
'--format=csv,nounits,noheader'], encoding='utf-8').split('\n')[gpu_id])
return int(mb2b(reading))
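# Note (assumed nvidia-smi behaviour): the query prints one MiB value per GPU,
# one value per line, e.g. "910\n0\n" for two devices, so gpu_id simply selects
# the corresponding line before converting mb to bytes.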
class Tracemalloc():
def __init__(self, measure_fn):
super().__init__()
self.measure_fn = measure_fn
def __enter__(self):
self.begin = self.measure_fn()
return self
def __exit__(self, *exc):
self.end = self.measure_fn()
self.used = self.end - self.begin
logger.info(f"Measured {self.used} byte GPU mem consumption")
class TorchTracemalloc(Tracemalloc):
def __init__(self, gpu_id: int = None):
if gpu_id is not None:
fn = partial(torch.cuda.memory_reserved, device=gpu_id)
else:
fn = torch.cuda.memory_reserved
super().__init__(measure_fn=fn)
def __enter__(self):
super().__enter__()
torch.cuda.reset_peak_memory_stats() # reset the peak to zero
return self
def __exit__(self, *exc):
super().__exit__()
self.peak = torch.cuda.max_memory_allocated()
self.peaked = self.peak - self.begin
logger.info(f"Measured peak {self.used} byte GPU mem consumption")
class SmiTracemalloc(Tracemalloc):
def __init__(self, gpu_id: int = None):
if gpu_id is not None:
fn = partial(smi_memory_allocated, gpu_id=gpu_id)
else:
fn = smi_memory_allocated
super().__init__(measure_fn=fn)
@contextmanager
def cudnn_deterministic():
old_value = torch.backends.cudnn.deterministic
old_value_benchmark = torch.backends.cudnn.benchmark
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
try:
yield None
finally:
torch.backends.cudnn.deterministic = old_value
torch.backends.cudnn.benchmark = old_value_benchmark
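# Minimal usage sketch (network and batch are placeholders): the previous cudnn
# flags are restored even if the wrapped block raises, e.g.
#   with cudnn_deterministic():
#       prediction = network(batch)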
from typing import Mapping, Type
from nndet.utils.registry import Registry
from nndet.planning.experiment.base import PlannerType, AbstractPlanner
PLANNER_REGISTRY: Mapping[str, Type[PlannerType]] = Registry()
from nndet.planning.experiment.v001 import D3V001
from __future__ import annotations
import os
from pathlib import Path
from itertools import repeat
from multiprocessing import Pool
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, Optional, List, TypeVar
import numpy as np
from loguru import logger
from nndet.io.load import load_pickle, save_pickle
from nndet.io.paths import (
get_case_ids_from_dir,
get_paths_from_splitted_dir,
)
from nndet.planning.architecture.abstract import ArchitecturePlannerType
from nndet.preprocessing.preprocessor import PreprocessorType
from nndet.planning.experiment.utils import run_create_label_preprocessed
class AbstractPlanner(ABC):
def __init__(self,
preprocessed_output_dir: os.PathLike,
):
"""
Base class for experiment planning
Args:
preprocessed_output_dir: path to directory where preprocessed
data will be saved
"""
super().__init__()
self.preprocessed_output_dir = Path(preprocessed_output_dir)
self.transpose_forward = None
self.transpose_backward = None
self.anisotropy_threshold = 3
self.resample_anisotropy_threshold = 3
self.target_spacing_percentile = 50
self.data_properties = self.load_data_properties()
@abstractmethod
def plan_experiment(self,
model_name: str,
model_cfg: Dict,
) -> List[str]:
"""
Plan the whole experiment
Args:
model_name: name of model to plan for
model_cfg: config to initialize model for VRAM estimation
Returns:
List: identifiers of created plans
"""
raise NotImplementedError
@abstractmethod
def create_architecture_planner(self,
model_name: str,
model_cfg: dict,
mode: str,
) -> ArchitecturePlannerType:
"""
Create Architecture planner
Args:
model_name: name of model to plan for
model_cfg: config to initialize model for VRAM estimation
mode: current mode of experiment planner
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def create_preprocessor(plan: Dict) -> PreprocessorType:
"""
Create Preprocessor
"""
raise NotImplementedError
@abstractmethod
def determine_forward_backward_permutation(self, mode: str):
"""
Permute dimensions of input. Results should be saved into
:param:`transpose_forward` and :param:`transpose_backward`
Args:
mode: define current operation mode. Typically one of
'2d' | '3d' | '3dlr1'
Raises:
NotImplementedError: Should be overwritten in subclasses
"""
raise NotImplementedError
@abstractmethod
def determine_target_spacing(self, mode: str) -> np.ndarray:
"""
Determine target spacing.
Args:
mode: define current operation mode. Typically one of
'2d' | '3d' | '3dlr1'
Same as nnUNet v21
https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py
"""
raise NotImplementedError
def load_data_properties(self):
"""
Load properties from analysis of dataset
Returns:
dict: loaded properties
"""
data_properties_path = self.preprocessed_output_dir / "properties" / "dataset_properties.pkl"
assert data_properties_path.is_file(), "data properties need to exist. Run data analysis first"
data_properties = load_pickle(data_properties_path)
return data_properties
def get_data_identifier(self, mode: str):
"""
By default each plan is associated with its own folder.
If only the architecture changed, this can be overwritten
to use the data from a different plan (useful for dev)
"""
return f"{self.__class__.__name__}_{mode}"
def plan_base(self, mode: str) -> Dict:
"""
Create the base plan
Args:
mode: define current operation mode. Typically one of
'2d' | '3d' | '3dlr1'
Returns:
Dict: plan with base attributes
'mode': selected mode for plan
`target_spacing`: target to resample data
`normalization_schemes`: normalization type for each modality
`use_mask_for_norm`: use mask for norm
`anisotropy_threshold`: threshold used to trigger anisotropy
settings
`resample_anisotropy_threshold`: threshold to trigger different
resampling schemes
`target_spacing_percentile`: target spacing percentile
used to create target spacing
`dim`: dimensionality of data (2 or 3)
`transpose_forward`: transpose forward order
`transpose_backward`: transpose back order
`list_of_npz_files`: files used for preprocessing
"""
use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
logger.info(f"Are we using the nonzero maks for normalization? {use_nonzero_mask_for_normalization}")
target_spacing = self.determine_target_spacing(mode=mode)
logger.info(f"Base target spacing is {target_spacing}")
self.determine_forward_backward_permutation(mode=mode)
normalization_schemes = self.determine_normalization()
logger.info(f"Normalization schemes {normalization_schemes}")
plan = {
'mode': mode,
'target_spacing': target_spacing,
'normalization_schemes': normalization_schemes,
'use_mask_for_norm': use_nonzero_mask_for_normalization,
'anisotropy_threshold': self.anisotropy_threshold,
'resample_anisotropy_threshold': self.resample_anisotropy_threshold,
'target_spacing_percentile': self.target_spacing_percentile,
'dim': self.data_properties['dim'],
"num_modalities": len(list(self.data_properties['modalities'].keys())),
"all_classes": self.data_properties['all_classes'],
"num_classes": len(self.data_properties['all_classes']),
'transpose_forward': self.transpose_forward,
'transpose_backward': self.transpose_backward,
'dataset_properties': self.data_properties,
"planner_id": self.__class__.__name__,
}
return plan
def plan_base_stage(self,
base_plan: Dict,
model_name: str,
model_cfg: dict,
):
"""
Plan the first stage of training
Args:
base_plan: basic plan
model_name: name of model to plan for
model_cfg: config to initialize model for VRAM estimation
Returns:
dict: properties of stage
`patch_size`
`batch_size`
`architecture` (dict): kwargs for architecture
`current_spacing`
`original_spacing`
`median_shape_transposed`
`do_dummy_2D_data_aug`
"""
target_spacing = base_plan['target_spacing']
spacings = self.data_properties['all_spacings']
sizes = self.data_properties['all_sizes']
new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
median_shape = np.median(np.vstack(new_shapes), 0)
logger.info(f"The median shape of the dataset is {median_shape}")
max_shape = np.max(np.vstack(new_shapes), 0)
logger.info(f"The max shape in the dataset is {max_shape}")
min_shape = np.min(np.vstack(new_shapes), 0)
logger.info(f"The min shape in the dataset is {min_shape}")
target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
median_shape_transposed = np.array(median_shape)[self.transpose_forward]
logger.info(f"The transposed median shape of the dataset is {median_shape_transposed}")
architecture_planner = self.create_architecture_planner(
model_name=model_name,
model_cfg=model_cfg,
mode=base_plan["mode"],
)
architecture_plan = architecture_planner.plan(
target_spacing_transposed=target_spacing_transposed,
median_shape_transposed=median_shape_transposed,
transpose_forward=self.transpose_forward,
mode=base_plan["mode"],
)
patch_size = architecture_plan["patch_size"]
do_dummy_2d_data_aug = (max(patch_size) / min(patch_size)) > self.anisotropy_threshold
base_plan.update(architecture_plan)
base_plan["target_spacing_transposed"] = target_spacing_transposed
base_plan["median_shape_transposed"] = median_shape_transposed
base_plan["do_dummy_2D_data_aug"] = do_dummy_2d_data_aug
return base_plan
def determine_postprocessing(self, mode: str) -> dict:
"""
Placeholder for the future
Args:
mode: define current operation mode. Typically one of
'2d' | '3d' | '3dlr1'
Deprecated version returned:
'keep_only_largest_region'
'min_region_size_per_class'
'min_size_per_class'
"""
logger.warning("No planning for post-processing implemented.")
return {}
def determine_normalization(self) -> Dict[int, str]:
"""
Determine normalization scheme for data
Returns:
Dict[int, str]: integer index represents modality and string is
either `CT` or `nonCT`
"""
schemes = OrderedDict()
modalities = self.data_properties['modalities']
num_modalities = len(list(modalities.keys()))
for i in range(num_modalities):
if modalities[i] == "CT":
schemes[i] = "CT"
elif modalities[i] == "CT2":
schemes[i] = "CT2"
else:
schemes[i] = "nonCT"
return schemes
def determine_whether_to_use_mask_for_norm(self) -> Dict[int, bool]:
"""
Determine if only foreground values should be used for normalization for all modalities
Returns:
Dict[int, bool]: result for each modality
"""
# only use the nonzero mask for normalization if the cropping based on it resulted in a decrease in
# image size (this is an indication that the data is something like brats/isles and then we want to
# normalize in the brain region only)
modalities = self.data_properties['modalities']
num_modalities = len(list(modalities.keys()))
use_mask_for_norm = OrderedDict()
for i in range(num_modalities):
if "CT" in modalities[i]:
use_mask_for_norm[i] = False
else:
all_size_reductions = list(self.data_properties["size_reductions"].values())
if np.median(all_size_reductions) < 3 / 4.:
logger.info("using nonzero mask for normalization")
use_mask_for_norm[i] = True
else:
logger.info("not using nonzero mask for normalization")
use_mask_for_norm[i] = False
return use_mask_for_norm
def save_plan(self, plan: dict, mode: str) -> str:
"""
Save plan
Args:
plan: plan to save
mode: plan mode
Returns:
str: plan identifier
"""
self.preprocessed_output_dir.mkdir(
parents=True,
exist_ok=True,
)
identifier = f"{self.__class__.__name__}_{mode}"
save_pickle(plan, self.preprocessed_output_dir / f"{identifier}.pkl")
return identifier
def run_preprocessing(
self,
cropped_data_dir: os.PathLike,
plan: dict,
num_processes: int,
):
"""
Runs data preprocessing
Args:
cropped_data_dir: base cropped dir
plan: plan to use for preprocessing
num_processes: number of processes to use for preprocessing
"""
preprocessor = self.create_preprocessor(plan=plan)
preprocessor.run(
target_spacings=[plan["target_spacing"]],
identifiers=[plan["data_identifier"]],
cropped_data_dir=Path(cropped_data_dir),
preprocessed_output_dir=self.preprocessed_output_dir,
num_processes=num_processes,
)
self.create_labels_tr_preprocessed(
preprocessed_plan_dir=self.preprocessed_output_dir / plan["data_identifier"],
dim=3,
num_processes=num_processes,
)
@staticmethod
def create_labels_tr_preprocessed(
preprocessed_plan_dir: Path,
dim: int,
num_processes: int = 6,
):
"""
Creates labels for visualization and analysis purposes from
preprocessed data
Args:
preprocessed_plan_dir: path to preprocessed plan dir
dim: number of spatial dimensions
num_processes: number of processes to use
"""
source_dir = preprocessed_plan_dir / "imagesTr"
target_dir = preprocessed_plan_dir / "labelsTr"
target_dir.mkdir(parents=True, exist_ok=True)
case_ids = get_case_ids_from_dir(source_dir,
remove_modality=False,
pattern="*.npz",
)
logger.info('Preparing preprocessed evaluation labels')
if num_processes > 0:
with Pool(processes=num_processes) as p:
p.starmap(run_create_label_preprocessed,
zip(repeat(source_dir),
case_ids,
repeat(dim),
repeat(target_dir),
)
)
else:
for cid in case_ids:
run_create_label_preprocessed(source_dir, cid, dim, target_dir)
@classmethod
def run_preprocessing_test(cls,
preprocessed_output_dir: os.PathLike,
splitted_4d_output_dir: os.PathLike,
plan: dict,
num_processes: int = 0,
):
"""
Run preprocessing of test data
Args:
preprocessed_output_dir: directory where preprocessed data is saved
splitted_4d_output_dir: base dir of splitted data
plan: plan to use for preprocessing
num_processes: number of processes to use for preprocessing
"""
logger.info("Running preprocessing of test cases")
splitted_4d_output_dir = Path(splitted_4d_output_dir)
target_dir = Path(preprocessed_output_dir) / plan["data_identifier"] / "imagesTs"
target_dir.mkdir(parents=True, exist_ok=True)
cases_processed = get_case_ids_from_dir(
target_dir,
remove_modality=False,
pattern="*.npz",
)
cases = get_paths_from_splitted_dir(
num_modalities=plan["num_modalities"],
splitted_4d_output_dir=splitted_4d_output_dir,
test=True,
labels=False,
remove_ids=cases_processed,
)
logger.info(f"Found {len(cases)} cases for preprocssing in {splitted_4d_output_dir} "
f"and {len(cases_processed)} alrady processed cases.")
preprocessor = cls.create_preprocessor(plan=plan)
if num_processes > 0:
with Pool(processes=num_processes) as p:
p.starmap(preprocessor.run_test,
zip(cases,
repeat(plan["target_spacing"]),
repeat(target_dir),
)
)
else:
for c in cases:
preprocessor.run_test(c, plan["target_spacing"], target_dir)
PlannerType = TypeVar('PlannerType', bound=AbstractPlanner)
import os
from nndet.core.boxes.ops_np import box_size_np
import numpy as np
from pathlib import Path
from loguru import logger
from itertools import repeat
from multiprocessing import Pool
from typing import Dict
from nndet.io.itk import load_sitk_as_array
from nndet.io.load import load_json, load_pickle
from nndet.io.paths import get_case_ids_from_dir
from nndet.io.transforms.instances import (
get_bbox_np,
instances_to_segmentation_np,
)
def create_label_case(
target_dir: Path,
case_id: str,
instances: np.ndarray,
mapping: Dict[int, int],
dim: int,
) -> None:
"""
Create labels for evaluation and analysis purposes
Args:
target_dir: target dir to save labels
case_id: case identifier
instances: instance segmentation
mapping: map each instance id to a class (classes start from 0)
dim: spatial dimensions
"""
instances_save_path = target_dir / f"{case_id}_instances_gt.npz"
boxes_save_path = target_dir / f"{case_id}_boxes_gt.npz"
seg_save_path = target_dir / f"{case_id}_seg_gt.npz"
if instances_save_path.is_file() and boxes_save_path.is_file() and seg_save_path.is_file():
logger.warning(f"Skipping prepare label {case_id} because it already exists")
else:
logger.info(f"Preparing label {case_id}")
if instances.ndim == dim:
instances = instances[None]
np.savez_compressed(str(instances_save_path),
instances=instances, mapping=mapping,
)
res = get_bbox_np(instances, mapping, dim=dim)
np.savez_compressed(str(boxes_save_path), **res)
seg = instances_to_segmentation_np(instances, mapping)
np.savez_compressed(str(seg_save_path), seg=seg)
def create_labels(
preprocessed_output_dir: os.PathLike,
source_dir: os.PathLike,
num_processes: int = 6,
):
"""
Creates labels for visualization and analysis purposes from raw labels
Prepares: instance segmentation, bounding boxes, semantic segmentation
Args:
preprocessed_output_dir: directory where the prepared labels are saved
source_dir: base dir which contains labelsTr/labelsTs
num_processes: number of processes to use
"""
source_dir = Path(source_dir)
for postfix in ["Tr", "Ts"]:
if (source_label_dir := source_dir / f"labels{postfix}").is_dir():
logger.info(f'Preparing {postfix} evaluation labels')
target_dir = Path(preprocessed_output_dir) / f"labels{postfix}"
target_dir.mkdir(parents=True, exist_ok=True)
case_ids = get_case_ids_from_dir(source_label_dir,
remove_modality=False,
pattern="*.json",
)
if num_processes > 0:
with Pool(processes=num_processes) as p:
p.starmap(run_create_label,
zip(repeat(source_label_dir),
case_ids,
repeat(3),
repeat(target_dir),
)
)
else:
for cid in case_ids:
run_create_label(source_label_dir, cid, 3, target_dir)
def run_create_label(source_label_dir: Path,
case_id: str,
dim: int,
target_dir: Path,
):
"""
Helper to run preparation with multiprocessing
Args:
source_label_dir: directory with labels
case_id: case id to process
dim: number of spatial dimensions
target_dir: directory to save results
"""
instances = load_sitk_as_array(source_label_dir / f"{case_id}.nii.gz")[0]
properties = load_json(source_label_dir / f"{case_id}.json")
if instances.ndim == dim:
instances = instances[None]
instances = instances.astype(np.int32)
mapping = {int(key): int(item) for key, item in properties["instances"].items()}
create_label_case(
target_dir=target_dir,
case_id=case_id,
instances=instances,
mapping=mapping,
dim=dim,
)
def run_create_label_preprocessed(
source_dir: Path,
case_id: str,
dim: int,
target_dir: Path,
):
"""
Helper to run preparation with multiprocessing
Args:
source_dir: directory with labels
case_id: case id to process
dim: number of spatial dimensions
target_dir: directory to save results
"""
instances = np.load(str(source_dir / f"{case_id}.npz"), mmap_mode="r")["seg"]
properties = load_pickle(source_dir / f"{case_id}.pkl")
mapping = {int(key): int(item) for key, item in properties["instances"].items()}
create_label_case(
target_dir=target_dir,
case_id=case_id,
instances=instances,
mapping=mapping,
dim=dim,
)
from typing import Dict, List, Sequence
import numpy as np
from loguru import logger
from nndet.ptmodule import MODULE_REGISTRY
from nndet.planning.experiment import PLANNER_REGISTRY, AbstractPlanner
from nndet.planning.estimator import MemoryEstimatorDetection
from nndet.planning.architecture.boxes import BoxC002
from nndet.preprocessing.preprocessor import GenericPreprocessor
from nndet.core.boxes.ops_np import box_size_np
from nndet.planning.architecture.boxes.utils import concatenate_property_boxes
@PLANNER_REGISTRY.register
class D3V001(AbstractPlanner):
def plan_experiment(self,
model_name: str,
model_cfg: Dict,
) -> List[str]:
"""
Plan the whole experiment (currently only one stage is supported)
(uses :func:`self.save_plans()` to save the results)
Args:
model_name: name of model to plan for
model_cfg: config to initialize model for VRAM estimation
Returns:
List: identifiers of created plans
"""
identifiers = []
# create full resolution 3d plan
mode = "3d"
plan_3d = self.plan_base(mode=mode)
plan_3d["network_dim"] = 3
plan_3d["dataloader_kwargs"] = {}
plan_3d["data_identifier"] = self.get_data_identifier(mode=mode)
plan_3d["postprocessing"] = self.determine_postprocessing(mode=mode)
plan_3d = self.plan_base_stage(
plan_3d,
model_name=model_name,
model_cfg=model_cfg,
)
# determine if additional low res model needs to be trained
plan_3d["trigger_lr1"] = self.trigger_low_res_model(
prev_res_patch_size=plan_3d["patch_size"],
transpose_forward=plan_3d["transpose_forward"],
)
identifiers.append(self.save_plan(plan=plan_3d, mode=plan_3d["mode"]))
if plan_3d["trigger_lr1"]:
logger.info("Triggered Low Resolution Model")
mode = "3dlr1"
plan_3dlr1 = self.plan_base(mode=mode)
plan_3dlr1["network_dim"] = 3
plan_3dlr1["dataloader_kwargs"] = {}
plan_3dlr1["data_identifier"] = self.get_data_identifier(mode=mode)
plan_3dlr1["postprocessing"] = self.determine_postprocessing(mode=mode)
plan_3dlr1 = self.plan_base_stage(
plan_3dlr1,
model_name=model_name,
model_cfg=model_cfg,
)
identifiers.append(self.save_plan(plan=plan_3dlr1, mode=plan_3dlr1["mode"]))
return identifiers
def create_architecture_planner(self,
model_name: str,
model_cfg: dict,
mode: str,
) -> BoxC002:
"""
Create Architecture planner
"""
estimator = MemoryEstimatorDetection()
architecture_planner = BoxC002(
preprocessed_output_dir=self.preprocessed_output_dir,
save_dir=self.preprocessed_output_dir / "analysis" / f"{self.__class__.__name__}_{mode}",
estimator=estimator,
network_cls=MODULE_REGISTRY.get(model_name),
model_cfg=model_cfg,
)
return architecture_planner
@staticmethod
def create_preprocessor(plan: Dict) -> GenericPreprocessor:
"""
Create Preprocessor
"""
preprocessor = GenericPreprocessor(
norm_scheme_per_modality=plan['normalization_schemes'],
use_mask_for_norm=plan['use_mask_for_norm'],
transpose_forward=plan['transpose_forward'],
intensity_properties=plan['dataset_properties']['intensity_properties'],
resample_anisotropy_threshold=plan['resample_anisotropy_threshold'],
)
return preprocessor
def determine_forward_backward_permutation(self, mode: str):
"""
Determine position of z direction (absolute position is defined by z_first)
Results are saved into :param:`transpose_forward` and :param:`transpose_backward`
"""
spacings = self.data_properties['all_spacings']
sizes = self.data_properties['all_sizes']
target_spacing = self.determine_target_spacing(mode=mode)
new_sizes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
dims = len(target_spacing)
max_spacing_axis = np.argmax(target_spacing)
remaining_axes = [i for i in list(range(dims)) if i != max_spacing_axis]
# self.transpose_forward = remaining_axes + [max_spacing_axis] # y, x, z
self.transpose_forward = [max_spacing_axis] + remaining_axes # z, y, x
self.transpose_backward = [np.argwhere(np.array(
self.transpose_forward) == i)[0][0] for i in range(dims)]
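# Hypothetical example: a target spacing of [0.7, 0.7, 3.0] has its largest
# value on axis 2, so transpose_forward becomes [2, 0, 1] (z first) and
# transpose_backward is [1, 2, 0], which restores the original axis order.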
def determine_target_spacing(self, mode: str) -> np.ndarray:
"""
Determine target spacing
Args:
mode: Current planning mode. Typically one of '2d' | '3d' | '3dlr1'
Raises:
RuntimeError: not supported mode (supported are 2d, 3d, 3dlrX)
Returns:
np.ndarray: target spacing
"""
base_target_spacing = self._target_spacing_base()
if mode == "3d" or mode == "2d":
target_spacing = base_target_spacing
else:
if not "lr" in mode:
raise RuntimeError(f"Mode {mode} is not supported for target spacing.")
downscale = int(mode.split('lr')[-1])
target_spacing = base_target_spacing * (2 ** downscale)
return target_spacing
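# Hypothetical example: for a base target spacing of [3.0, 0.7, 0.7], mode
# "3dlr1" gives downscale = 1 and therefore a target spacing of [6.0, 1.4, 1.4].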
def _target_spacing_base(self) -> np.ndarray:
"""
Determine target spacing.
Same as nnUNet v21
https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py
"""
spacings = self.data_properties['all_spacings']
sizes = self.data_properties['all_sizes']
target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
target_size_mm = np.array(target) * np.array(target_size)
# we need to identify datasets for which a different target spacing could be beneficial. These datasets have
# the following properties:
# - one axis with a much lower resolution than the others
# - the lowres axis has far fewer voxels than the others
# - (the size in mm of the lowres axis is also reduced)
worst_spacing_axis = np.argmax(target)
other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
other_spacings = [target[i] for i in other_axes]
other_sizes = [target_size[i] for i in other_axes]
has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * min(other_spacings))
has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
# we don't use the last one for now
# median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)
if has_aniso_spacing and has_aniso_voxels:
spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
# don't let the spacing of that axis get higher than the other axes
if target_spacing_of_that_axis < min(other_spacings):
target_spacing_of_that_axis = max(min(other_spacings), target_spacing_of_that_axis) + 1e-5
target[worst_spacing_axis] = target_spacing_of_that_axis
return target
def trigger_low_res_model(
self,
prev_res_patch_size: Sequence[int],
transpose_forward: Sequence[int],
) -> bool:
"""
Trigger additional low resolution model
Args:
prev_res_patch_size: patch size of previous stage
Returns:
bool: If True, trigger a low resolution model. If False, current
resolution is ok.
"""
all_boxes = [case["boxes"] for case_id, case in \
self.data_properties["instance_props_per_patient"].items()]
all_boxes = concatenate_property_boxes(all_boxes)
object_size = np.percentile(box_size_np(all_boxes), 99.5, axis=0)
object_size = object_size[list(transpose_forward)]
if (np.asarray(prev_res_patch_size) < object_size).any():
return True
else:
return False
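# Hypothetical usage sketch (the directory, model name, and config are
# placeholders): a planner is looked up via the registry, instantiated with the
# preprocessed output dir (which must already contain the dataset properties),
# and plan_experiment returns the identifiers of the created plans, e.g.
#   planner_cls = PLANNER_REGISTRY.get("D3V001")
#   planner = planner_cls(preprocessed_output_dir="./preprocessed")
#   identifiers = planner.plan_experiment(model_name=model_name, model_cfg=model_cfg)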
from nndet.planning.properties.instance import analyze_instances
from nndet.planning.properties.intensity import (
analyze_intensities,
get_modalities,
)
from nndet.planning.properties.medical import (
get_sizes_and_spacings_after_cropping,
get_size_reduction_by_cropping,
)
from nndet.planning.properties.segmentation import analyze_segmentations
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pickle
import numpy as np
from loguru import logger
from collections import OrderedDict, defaultdict
from multiprocessing.pool import Pool
from itertools import repeat
from typing import Dict, Sequence, List, Tuple
from nndet.io.load import load_case_cropped
from nndet.planning import DatasetAnalyzer
from nndet.core.boxes import box_iou_np
def analyze_instances(analyzer: DatasetAnalyzer) -> dict:
"""
Analyze instance segmentations
Args:
analyzer (DatasetAnalyzer): calling analyzer
Returns:
dict: extracted properties
"""
class_dct = analyzer.data_info["labels"]
all_classes = np.array([int(i) for i in class_dct.keys()])
if analyzer.overwrite or not analyzer.props_per_case_file.is_file():
props_per_case = run_analyze_instances(analyzer, all_classes)
else:
with open(analyzer.props_per_case_file, "rb") as f:
props_per_case = pickle.load(f)
output = {'class_dct': class_dct,
'all_classes': all_classes,
'instance_props_per_patient': props_per_case
}
output.update(analyze_instances_data_set(props_per_case))
return output
def run_analyze_instances(analyzer: DatasetAnalyzer,
all_classes: Sequence[int],
save: bool = True,
):
"""
Analyze all instance segmentation from data set
Args:
analyzer: calling analyzer
all_classes: all classes present in dataset
save: save properties per case as pickle
file :param:`analyzer.props_per_case_file`
Returns:
Dict: extract properties per case id [case_id, property_dict]
"""
props_per_case = OrderedDict()
with Pool(analyzer.num_processes) as p:
props = p.starmap(analyze_instances_per_case, zip(
repeat(analyzer), analyzer.case_ids, repeat(all_classes)))
# props = [analyze_instances_per_case(analyzer, cid, all_classes) for cid in analyzer.case_ids]
for case_id, prop in zip(analyzer.case_ids, props):
props_per_case[case_id] = prop
if save:
with open(analyzer.props_per_case_file, "wb") as f:
pickle.dump(props_per_case, f)
return props_per_case
def analyze_instances_data_set(props_per_case: OrderedDict) -> dict:
"""
Compute properties of instances over whole dataset
Args:
props_per_case: properties per case
`num_instances`: see :func:`count_instances`
`class_ious`: see :func:`instance_class_and_region_sizes`
`all_ious`: see :func:`instance_class_and_region_sizes`
Returns:
Dict: properties extracted from whole dataset
`num_instances`(Dict[int, int]): number of instances per class
`class_ious`(Dict[int, np.ndarray]): all flattened IoUs of
instances of the same class
`all_ious`(Dict[int, np.ndarray]): all flattened IoUs of
instances regardless their class
"""
data_props = {}
num_instances = defaultdict(int)
class_ious = defaultdict(list)
for case_id, case_props in props_per_case.items():
for cls, count in case_props["num_instances"].items():
num_instances[cls] += count
for cls, ious in case_props["class_ious"].items():
class_ious[cls].append(ious.flatten())
data_props["num_instances"] = num_instances
for cls in class_ious.keys():
class_ious[cls] = np.concatenate(class_ious[cls])
data_props["class_ious"] = class_ious
data_props["all_ious"] = np.concatenate([case_props["all_ious"].flatten()
for _, case_props in props_per_case.items()])
return data_props
def analyze_instances_per_case(analyzer: DatasetAnalyzer,
case_id: str,
all_classes: Sequence[int],
):
"""
Analyze a single case
Args:
analyzer: calling analyzer
case_id: case identifier
all_classes: all classes present in dataset
Returns:
Dict[str, Any]: properties extracted per case. See:
`num_instances` (Dict[int, int]): number of instances per class
`has_classes` (Sequence[int]): classes present in this case
`volume_per_class(Dict[int, float])`: volume per class (sum of
all instance volume corresponding to class)
[all_classes, volume]
`region_volume_per_class`(Dict[int, List[float]]): volume of
each instance (sorted to corresponding class)
[all_classes, list(region_class_volume)]
`boxes`(np.ndarray): bounding boxes (x1, y1, x2, y2, (z1, z2))[N, dims * 2]
`all_ious`(np.ndarray): IoU values between all boxes independent of class
`class_ious`(Dict[int, np.ndarray]): IoU values of boxes with respect to classes
"""
logger.info(f"Processing instance properties of case {case_id}")
_, iseg, props = load_case_cropped(analyzer.cropped_data_dir, case_id)
props["num_instances"] = count_instances(props, all_classes)
props["has_classes"] = list(set(props["instances"].values()))
props["volume_per_class"], props["region_volume_per_class"] = \
instance_class_and_region_sizes(case_id, iseg, props, all_classes)
props["boxes"] = iseg_to_boxes(iseg)
props["all_ious"], props["class_ious"] = case_ious(props["boxes"], props)
return props
def count_instances(props: dict, all_classes: Sequence[int]) -> Dict[int, int]:
"""
Count instance classes inside one case
Args:
props: additional properties
`instances` (Dict[int, int]): maps each instance to a numerical class
all_classes: all classes in dataset
Returns:
Dict[int, int]: number of instances per class [all_classes, count]
"""
instance_classes = list(map(int, props["instances"].values()))
return {int(c): instance_classes.count(int(c)) for c in all_classes}
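# Hypothetical example: props["instances"] == {"1": 0, "2": 0, "3": 1} together
# with all_classes == [0, 1] yields {0: 2, 1: 1}.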
def instance_class_and_region_sizes(
case_id: str,
iseg: np.ndarray,
props: dict,
all_classes: Sequence[int],
) -> Tuple[
Dict[int, float], Dict[int, List[float]]]:
"""
Compute physical volume of all instances
Classes which are not present in the case map to 0 or an empty list.
Args:
case_id: case identifier
iseg: instance segmentation
props: additional properties
`itk_spacing` (Sequence[float]): spacing information
`instances` (Dict[int, int]): maps each instance to a numerical class
all_classes: all classes in dataset
Returns:
Dict[int, float]: volume per class (sum of all instance volume
corresponding to class) [all_classes, volume]
Dict[int, List[float]]: volume of each instance (sorted to
corresponding class) [all_classes, list(region_class_volume)]
"""
vol_per_voxel = np.prod(props['itk_spacing'])
instance_classes = {int(key): int(item) for key, item in props["instances"].items()}
volume_per_class = OrderedDict(zip(all_classes, [0] * len(all_classes)))
region_volume_per_class = OrderedDict()
ids = np.unique(iseg)
ids = ids[ids > 0]
if len(ids) != len(list(instance_classes.keys())):
logger.warning(f"Instance lost. Found {instance_classes} in "
f"properties but {ids} in seg of {case_id}.")
volume_per_instance = {c: np.sum(iseg == c) * vol_per_voxel for c in ids}
for instance_id, instance_vol in volume_per_instance.items():
i_cls = instance_classes[instance_id]
volume_per_class[i_cls] += instance_vol
if i_cls in region_volume_per_class:
region_volume_per_class[i_cls].append(instance_vol)
else:
region_volume_per_class[i_cls] = [instance_vol]
return volume_per_class, region_volume_per_class
def iseg_to_boxes(iseg: np.ndarray) -> np.ndarray:
"""
Convert instance segmentations to bounding boxes
Args:
iseg: instance segmentation [dims] (NO channel dim)
Returns:
(np.ndarray): bounding boxes (x1, y1, x2, y2, (z1, z2))[N, dims * 2]
(order of boxes corresponds to instance ids)
Notes:
Please refer to `nndet.io.transforms.instances` for the function
and don't use this one.
"""
boxes = []
ids = np.unique(iseg)
ids = ids[ids > 0]
for instance_id in ids:
instance_idx = np.argwhere(iseg == instance_id)
coord_list = [np.min(instance_idx[:, 0]) - 1,
np.min(instance_idx[:, 1]) - 1,
np.max(instance_idx[:, 0]) + 1,
np.max(instance_idx[:, 1]) + 1]
if instance_idx.shape[1] == 3:
coord_list.extend([np.min(instance_idx[:, 2]) - 1,
np.max(instance_idx[:, 2]) + 1])
boxes.append(coord_list)
if boxes:
return np.stack(boxes)
else:
return []
def case_ious(boxes: np.ndarray, props: dict) -> Tuple[np.ndarray, Dict[int, np.ndarray]]:
"""
Compute IoU values for a single case (Evaluated both settings: all
bounding boxes and bounding boxes corresponding to a specific class)
Args:
boxes: bounding boxes of a single case (x1, y1, x2, y2, (z1, z2))[N, dims*2]
props: additional properties
`instances` (Dict[int, int]): maps each instance to a numerical class
Returns:
np.ndarray: IoU values of all bounding boxes
Dict[int, np.ndarray]: IoU values of bounding boxes which correspond
to a specific class
"""
if not isinstance(boxes, list):
all_ious = compute_each_iou(boxes)
class_ious = OrderedDict()
case_classes = list(set(map(int, props["instances"].values())))
case_instances = sorted(list(map(int, props["instances"].keys())))
for cls in case_classes:
cls_box_indices = [props["instances"][str(ci)] == cls for ci in case_instances]
class_ious[cls] = compute_each_iou(boxes[cls_box_indices])
else:
all_ious = np.array([])
class_ious = {}
return all_ious, class_ious
def compute_each_iou(boxes: np.ndarray):
"""
Compute IoU values from each box to each box (except the same box)
Args:
boxes: bounding boxes (x1, y1, x2, y2, (z1, z2))[N, dims*2]
Returns:
np.ndarray: computed IoUs [N, N-1]
"""
ious = box_iou_np(boxes, boxes)
# remove diagonal elements because they are always one
ious = ious[~np.eye(ious.shape[0], dtype=bool)].reshape(ious.shape[0], -1)
return ious
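# Hypothetical example: for three boxes, box_iou_np returns a 3x3 matrix; after
# dropping the diagonal of self-IoUs (always 1) a 3x2 matrix of pairwise
# overlaps remains, i.e. the returned shape is [N, N-1].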
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pickle
import numpy as np
from loguru import logger
from itertools import repeat
from multiprocessing import Pool
from collections import OrderedDict
from typing import Union, Sequence, Dict
from nndet.planning.analyzer import DatasetAnalyzer
from nndet.io.load import load_case_cropped
def get_modalities(analyzer: DatasetAnalyzer) -> dict:
"""
Extract modalities from analyzer data info
Args:
analyzer: calling analyzer; need to provide `modalities` dict in :param:`data_info`
Returns:
dict: extracted modalities
`modalities` (Dict[int, str]): modalities
"""
modalities = analyzer.data_info["modalities"]
modalities = {int(k): modalities[k] for k in modalities.keys()}
return {"modalities": modalities}
def analyze_intensities(analyzer: DatasetAnalyzer) -> dict:
"""
Either recompute or load intensity statistics from dataset
Args:
analyzer: calling analyzer; need to provide a dictionary where
modalities are named in :param:`data_info` in key `modalities`
Returns:
Dict:
`intensity_properties`: result of :func:`run_collect_intensity_properties`
"""
num_modalities = len(analyzer.data_info["modalities"].keys())
if analyzer.overwrite or not analyzer.intensity_properties_file.is_file():
results = run_collect_intensity_properties(analyzer, num_modalities)
else:
with open(analyzer.intensity_properties_file, 'rb') as f:
results = pickle.load(f)
return {'intensity_properties': results}
def run_collect_intensity_properties(analyzer: DatasetAnalyzer,
num_modalities: int, save: bool = True) -> Dict[int, Dict]:
"""
Collect intensity properties over the foreground of the whole dataset
Args:
analyzer: calling analyzer
num_modalities: number of modalities
save (optional): Save result in `analyzer.intensity_properties_file`. Defaults to True.
Returns:
Dict[int, Dict]: Intensity properties of foreground over the dataset.
Evaluated statistics: `median`; `mean`; `std`; `min`; `max`; `percentile_99_5`; `percentile_00_5`
`local_props`: contains a dict (with case ids) where statistics were computed per case
"""
with Pool(analyzer.num_processes) as p:
results = OrderedDict()
for mod_id in range(num_modalities):
logger.info(f"Processing intensity values of modality {mod_id}")
results[mod_id] = OrderedDict()
voxels = p.starmap(get_voxels_in_foreground,
zip(repeat(analyzer), analyzer.case_ids, repeat(mod_id)))
local_props = p.map(compute_stats, voxels)
props_per_case = OrderedDict()
for case_id, lp in zip(analyzer.case_ids, local_props):
props_per_case[case_id] = lp
all_voxels = []
for iv in voxels:
all_voxels += iv
results[mod_id]['local_props'] = props_per_case
results[mod_id].update(compute_stats(all_voxels))
if save:
with open(analyzer.intensity_properties_file, 'wb') as f:
pickle.dump(results, f)
return results
def get_voxels_in_foreground(analyzer: DatasetAnalyzer, case_id: str,
modality_id: int, subsample: int = 10) -> list:
"""
Get voxels from foreground
Args:
analyzer: calling analyzer
case_id: case identifier
modality_id: modality to choose for analyses
subsample (optional): Subsample voxels for computational purposes. Defaults to 10.
Returns:
list: foreground voxels
"""
data, seg, props = load_case_cropped(analyzer.cropped_data_dir, case_id)
modality = data[modality_id]
mask = seg > 0
voxels = list(modality[mask.astype(bool)][::subsample]) # no need to take every voxel
return voxels
def compute_stats(voxels: Union[Sequence, np.ndarray]):
"""
Compute statistics of voxels
Args:
voxels: input voxels
Returns:
Dict[str, np.ndarray]: computed statistics
`median`; `mean`; `std`; `min`; `max`; `percentile_99_5`; `percentile_00_5`
"""
if len(voxels) == 0:
stats = {"median": np.nan, "mean": np.nan, "std": np.nan, "min": np.nan,
"max": np.nan, "percentile_99_5": np.nan, "percentile_00_5": np.nan,
}
else:
stats = {
"median": np.median(voxels),
"mean": np.mean(voxels),
"std": np.std(voxels),
"min": np.min(voxels),
"max": np.max(voxels),
"percentile_99_5": np.percentile(voxels, 99.5),
"percentile_00_5": np.percentile(voxels, 00.5),
}
return stats
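# Hypothetical example: compute_stats([2.0, 4.0, 6.0]) returns median 4.0,
# mean 4.0, std ~1.63, min 2.0, max 6.0 plus the 0.5 and 99.5 percentiles;
# an empty input yields NaN for every statistic.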
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from typing import Dict, List
from collections import defaultdict, OrderedDict
from nndet.io.load import load_properties_of_cropped
from nndet.planning.analyzer import DatasetAnalyzer
def get_sizes_and_spacings_after_cropping(analyzer: DatasetAnalyzer) -> Dict[str, List]:
"""
Load all sizes and spacings after cropping
Args:
analyzer: analyzer which calls this property
Returns:
Dict[str, List]: loaded sizes and spacings inside list
`all_sizes`: contains all sizes
`all_spacings`: contains all spacings
"""
output = defaultdict(list)
for case_id in analyzer.case_ids:
properties = load_properties_of_cropped(analyzer.cropped_data_dir / case_id)
output['all_sizes'].append(properties["size_after_cropping"])
output['all_spacings'].append(properties["original_spacing"])
return output
def get_size_reduction_by_cropping(analyzer: DatasetAnalyzer) -> Dict[str, Dict]:
"""
Compute all size reductions of each case
Args:
analyzer: analyzer which calls this property
Returns:
Dict: computed size reductions
`size_reductions`: dictionary with each case id and reduction
"""
size_reduction = OrderedDict()
for case_id in analyzer.case_ids:
props = load_properties_of_cropped(analyzer.cropped_data_dir / case_id)
shape_before_crop = props["original_size_of_raw_data"]
shape_after_crop = props['size_after_cropping']
size_red = np.prod(shape_after_crop) / np.prod(shape_before_crop)
size_reduction[case_id] = size_red
return {"size_reductions": size_reduction}
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from nndet.planning.properties import (
get_sizes_and_spacings_after_cropping,
get_size_reduction_by_cropping,
get_modalities,
analyze_segmentations,
analyze_intensities,
analyze_instances,
)
def medical_segmentation_props(intensity_properties: bool = True):
"""
Default set for analysis of medical segmentation images
Args:
intensity_properties (optional): analyze intensity properties. Defaults to True.
Returns:
Sequence[Callable]: properties to calculate. Results can be summarized as follows:
See Also:
:func:`nndet.planning.medical.get_sizes_and_spacings_after_cropping`,
:func:`nndet.planning.medical.get_size_reduction_by_cropping`,
:func:`nndet.planning.intensity.get_modalities`,
:func:`nndet.planning.intensity.analyze_intensities`,
:func:`nndet.planning.segmentation.analyze_segmentations`,
"""
props = [
get_sizes_and_spacings_after_cropping,
get_size_reduction_by_cropping,
get_modalities,
analyze_segmentations,
]
if intensity_properties:
props.append(analyze_intensities)
else:
props.append(lambda x: {'intensity_properties': None})
return props
def medical_instance_props(intensity_properties: bool = True):
"""
Default set for analysis of medical instance segmentation images
Args:
intensity_properties (optional): analyze intensity properties. Defaults to True.
Returns:
Sequence[Callable]: properties to calculate. Results can be summarized as follows:
See Also:
:func:`nndet.planning.medical.get_sizes_and_spacings_after_cropping`,
:func:`nndet.planning.medical.get_size_reduction_by_cropping`,
:func:`nndet.planning.intensity.get_modalities`,
:func:`nndet.planning.intensity.analyze_intensities`,
:func:`nndet.planning.instance.analyze_instances`,
"""
props = [
get_sizes_and_spacings_after_cropping,
get_size_reduction_by_cropping,
get_modalities,
analyze_instances,
]
if intensity_properties:
props.append(analyze_intensities)
else:
props.append(lambda x: {'intensity_properties': None})
return props
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pickle
import numpy as np
from loguru import logger
from itertools import repeat
from collections import OrderedDict
from skimage.morphology import label
from multiprocessing import Pool
from typing import Dict, List, Sequence, Tuple, Callable
from nndet.planning.analyzer import DatasetAnalyzer
from nndet.io.load import load_case_cropped
def analyze_segmentations(analyzer: DatasetAnalyzer) -> dict:
"""
Analyze segmentation of dataset (if overwrite is disabled and analysis was already run,
this function will only load the results)
Args:
analyzer: analyzer which calls this function
Returns:
Dict:
`class_dct` (dict): label dictionary with all present classes
`all_classes`(np.ndarray): values of all foreground classes
`segmentation_props_per_patient`: result from :func:`run_analyze_segmentation`
"""
class_dct = analyzer.data_info["labels"]
all_classes = np.array([int(i) for i in class_dct.keys()])
if analyzer.overwrite or not analyzer.props_per_case_file.is_file():
props_per_case = run_analyze_segmentation(analyzer, all_classes)
else:
with open(analyzer.props_per_case_file, "rb") as f:
props_per_case = pickle.load(f)
return {'class_dct': class_dct, 'all_classes': all_classes,
'segmentation_props_per_patient': props_per_case}
def analyze_segmentation_per_case(analyzer: DatasetAnalyzer, case_id: str,
all_classes: Sequence[int]) -> Dict:
"""
1) what class is in this training case?
2) what is the size distribution for each class?
3) what is the region size of each class?
4) check if all in one region
Args:
analyzer: calling analyzer
case_id: case identifier
all_classes: all present classes in dataset
Returns:
Dict: region and class properties of case
`has_classes` (np.ndarray): present classes in case
`only_one_region` (Dict[Tuple[int], bool]):
contains information if individual classes are only present as a single region;
analyses if all classes build a single region;
can be indexed by the respective tuple of classes
`volume_per_class` ([Dict]): physical volume per class
`region_volume_per_class` (Dict[List]): physical volume per class per region
"""
logger.info(f"Processing segmentation properties of case {case_id}")
_, seg, props = load_case_cropped(analyzer.cropped_data_dir, case_id)
vol_per_voxel = np.prod(props['itk_spacing'])
unique_classes = np.unique(seg)
regions = [list(all_classes)]
for c in all_classes:
regions.append((c, ))
all_in_one_region = check_if_all_in_one_region(seg, regions)
volume_per_class, region_sizes = collect_class_and_region_sizes(
seg, all_classes, vol_per_voxel)
return {"has_classes": unique_classes, "only_one_region": all_in_one_region,
"volume_per_class": volume_per_class, "region_volume_per_class": region_sizes}
def run_analyze_segmentation(
analyzer: DatasetAnalyzer, all_classes: Sequence[int],
save: bool = True,
analyze_fn: Callable[[DatasetAnalyzer, str, Sequence[int]], Dict] = analyze_segmentation_per_case) \
-> Dict[str, Dict]:
"""
Analyze segmentations of all cases in analyzer
Args:
analyzer: analyzer which called this function
all_classes: values of all classes
save: Saves results as a file. Defaults to True.
(name is specified by `props_per_case_file` from analyzer)
analyze_fn: callable
to compute needed properties of a single segmentation case. Takes
the calling analyzer, the case id and a sequence of integers representing
all classes in the dataset and should return a single dict
Returns:
Dict[Dict]: computed properties per case
"""
props_per_case = OrderedDict()
with Pool(analyzer.num_processes) as p:
props = p.starmap(analyze_fn, zip(
repeat(analyzer), analyzer.case_ids, repeat(all_classes)))
for case_id, prop in zip(analyzer.case_ids, props):
props_per_case[case_id] = prop
if save:
with open(analyzer.props_per_case_file, "wb") as f:
pickle.dump(props_per_case, f)
return props_per_case
def check_if_all_in_one_region(seg: np.ndarray,
regions: Sequence[Sequence[int]]) -> Dict[Tuple[int], bool]:
"""
Check if regions are split over multiple instances or are all connected
Args:
seg: segmentation
regions: Sequence of multiple regions to analyze.
Each region can contain multiple classes
Returns:
Dict[Tuple[int], bool]: result for each region
"""
res = OrderedDict()
for r in regions:
new_seg = np.zeros(seg.shape)
for c in r:
new_seg[seg == c] = 1
labelmap, numlabels = label(new_seg, return_num=True)
res[tuple(r)] = numlabels == 1
return res
def collect_class_and_region_sizes(seg: np.ndarray, all_classes: Sequence[int],
vol_per_voxel: float) -> (Dict, Dict[str, Dict]):
"""
Collect class and region sizes from segmentation
Args:
seg: segmentation
all_classes: array with all classes
vol_per_voxel: physical volume per voxel
Returns:
Dict: volume per class (dict index corresponds to class)
Dict[List]: sizes of each region;
the outer dict is indexed by the class while the inner list indexes the regions
"""
volume_per_class = OrderedDict()
region_volume_per_class = OrderedDict()
for c in all_classes:
volume_per_class[c] = np.sum(seg == c) * vol_per_voxel
region_volume_per_class[c] = []
labelmap, numregions = label(seg == c, return_num=True)
for l in range(1, numregions + 1):
region_volume_per_class[c].append(np.sum(labelmap == l) * vol_per_voxel)
return volume_per_class, region_volume_per_class
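# Minimal usage sketch (not part of the original file): how the two helpers above
# behave on a tiny synthetic segmentation; array contents and spacing are made up.
def _demo_region_analysis():
    seg = np.zeros((8, 8, 8), dtype=np.int32)
    seg[1:3, 1:3, 1:3] = 1   # first connected component of class 1
    seg[5:7, 5:7, 5:7] = 1   # second connected component of class 1
    seg[4, 4, 4] = 2         # single voxel of class 2
    # class 1 is split into two regions, class 2 forms a single region
    one_region = check_if_all_in_one_region(seg, [(1,), (2,), (1, 2)])
    # physical volume per class and per connected region, assuming 1 mm^3 voxels
    vol_per_class, vol_per_region = collect_class_and_region_sizes(
        seg, [1, 2], vol_per_voxel=1.0)
    return one_region, vol_per_class, vol_per_region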
from nndet.preprocessing.crop import ImageCropper
from nndet.preprocessing.preprocessor import (
PreprocessorType,
AbstractPreprocessor,
GenericPreprocessor,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import shutil
import pickle
import numpy as np
from loguru import logger
from multiprocessing.pool import Pool
from pathlib import Path
from typing import List, Tuple, Sequence
from scipy.ndimage import binary_fill_holes
from nndet.io.paths import get_case_id_from_path
from nndet.io.load import load_case_from_list
def create_nonzero_mask(data: np.ndarray) -> np.ndarray:
"""
Create a nonzero mask from data
Args:
data (np.ndarray): input data [C, X, Y, Z] or [C, X, Y]
Returns:
np.ndarray: binary mask of nonzero regions [X, Y, Z] or [X, Y]
"""
assert len(data.shape) == 4 or len(data.shape) == 3, \
"data must have shape (C, X, Y, Z) or shape (C, X, Y)"
nonzero_mask = np.max(data != 0, axis=0)
nonzero_mask = binary_fill_holes(nonzero_mask.astype(bool))
return nonzero_mask
def get_bbox_from_mask(mask: np.ndarray, outside_value: int = 0) -> List[Tuple]:
"""
Create a bounding box from a mask
Args:
mask (np.ndarray): mask [X, Y, Z]
outside_value (int): background value
Returns:
List[Tuple]: [(dim0_min, dim0_max), (dim1_min, dim1_max), (dim2_min, dim2_max)] (last tuple only for 3d masks)
"""
mask_voxel_coords = (mask != outside_value).nonzero()
min0idx = int(np.min(mask_voxel_coords[0]))
max0idx = int(np.max(mask_voxel_coords[0])) + 1
min1idx = int(np.min(mask_voxel_coords[1]))
max1idx = int(np.max(mask_voxel_coords[1])) + 1
idx = [(min0idx, max0idx), (min1idx, max1idx)]
if len(mask_voxel_coords) == 3:
min2idx = int(np.min(mask_voxel_coords[2]))
max2idx = int(np.max(mask_voxel_coords[2])) + 1
idx.append((min2idx, max2idx))
return idx
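# Small sketch (not part of the original file): how a nonzero mask translates into
# a bounding box; the upper index is exclusive, so it can be used directly in slices.
def _demo_bbox_from_mask():
    data = np.zeros((1, 10, 10, 10), dtype=np.float32)
    data[0, 2:5, 3:6, 4:8] = 1.0
    mask = create_nonzero_mask(data)    # boolean mask [10, 10, 10]
    bbox = get_bbox_from_mask(mask, 0)  # [(2, 5), (3, 6), (4, 8)]
    return bbox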
def crop_to_bbox_no_channels(image, bbox: Sequence[Sequence[int]]):
"""
Crops image to bounding box (in spatial dimensions)
Args:
image (arraylike): 2d or 3d array
bbox (Sequence[Sequence[int]]): bounding box coordinates in an interleaved fashion
(e.g. (x1, x2), (y1, y2), (z1, z2))
Returns:
arraylike: cropped array
"""
resizer = tuple([slice(_dim[0], _dim[1]) for _dim in bbox])
return image[resizer]
def crop_to_bbox(data: np.ndarray, bbox: Sequence[Sequence[int]]):
"""
Crops image to bounding box (performed per channel)
Args:
data (np.ndarray): 3d or 4d array [C, X, Y, (Z)]
bbox (Sequence[Sequence[int]]): bounding box coordinates in an interleaved fashion
(e.g. (x1, x2), (y1, y2), (z1, z2))
Returns:
np.ndarray: cropped array
"""
cropped_data = []
for c in range(data.shape[0]):
cropped = crop_to_bbox_no_channels(data[c], bbox)
cropped_data.append(cropped)
data = np.stack(cropped_data)
return data
def crop_to_nonzero(data, seg=None, nonzero_label=-1):
"""
Crop data to nonzero region of data
Args:
data (np.ndarray): data to crop
seg (np.ndarray): segmentation
nonzero_label (int): nonzero label is written into segmentation map
where only background was found
Returns:
np.ndarray: cropped data
np.ndarray: cropped and filled (with nonzero_label) segmentation
List[Tuple[int]]: bounding box of nonzero region
"""
nonzero_mask = create_nonzero_mask(data)
bbox = get_bbox_from_mask(nonzero_mask, 0)
data = crop_to_bbox(data, bbox)
if seg is not None:
    seg = crop_to_bbox(seg, bbox)
nonzero_mask = crop_to_bbox_no_channels(nonzero_mask, bbox)[None]
if seg is not None:
    seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label
else:
    nonzero_mask = nonzero_mask.astype(np.int32)
    nonzero_mask[nonzero_mask == 0] = nonzero_label
    nonzero_mask[nonzero_mask > 0] = 0
    seg = nonzero_mask
return data, seg, bbox
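# Sketch (not part of the original file): cropping a zero-padded case to its
# nonzero region; shapes and values are synthetic.
def _demo_crop_to_nonzero():
    data = np.zeros((2, 16, 16, 16), dtype=np.float32)            # two modalities
    data[:, 4:12, 4:12, 4:12] = np.random.rand(2, 8, 8, 8) + 0.1  # nonzero block
    seg = np.zeros((1, 16, 16, 16), dtype=np.int32)
    seg[0, 6:8, 6:8, 6:8] = 1
    data_c, seg_c, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
    # the spatial shape shrinks from (16, 16, 16) to (8, 8, 8); voxels inside the
    # crop where all modalities are zero would be marked with -1 in the segmentation
    # (none exist in this toy example)
    return data_c, seg_c, bbox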
class ImageCropper(object):
def __init__(self, num_processes: int, output_dir: Path = None):
"""
Helper class to crop images to non zero region (must hold for all modalities)
In the case of BRaTS and ISLES data this results in a significant reduction in image size
Args:
num_processes (int): number of processes to use for cropping
output_dir (Path): path to output directory
"""
self.output_dir = Path(output_dir) if output_dir is not None else None
self.num_processes = num_processes
self.maybe_init_output_dir()
def maybe_init_output_dir(self):
if self.output_dir is not None:
(self.output_dir / "imagesTr").mkdir(parents=True, exist_ok=True)
(self.output_dir / "labelsTr").mkdir(parents=True, exist_ok=True)
def run_cropping(self,
case_files: List[List[Path]],
overwrite_existing: bool = False,
output_dir: Path = None,
copy_gt_data: bool = True,
):
"""
Crops data to the nonzero region and saves the result into output_dir
Optional: also copies ground truth data
Args:
case_files (List[List[Path]]): list with all cases in the structure [Case[Case Files]];
where case files are sorted to corresponding modalities (last file is the label file)
overwrite_existing (bool): overwrite existing crops
output_dir (Path): path to output directory
copy_gt_data (bool): copies ground truth data to output directory
"""
if output_dir is not None:
self.output_dir = Path(output_dir)
self.maybe_init_output_dir()
if copy_gt_data:
self.copy_gt_data(case_files)
list_of_args = []
for _i, case in enumerate(case_files):
case_id = get_case_id_from_path(str(case[0]))
assert not case_id.endswith(".gz") and not case_id.endswith(".nii")
list_of_args.append((case, case_id, overwrite_existing))
if self.num_processes == 0:
for a in list_of_args:
self.process_data(*a)
else:
with Pool(processes=self.num_processes) as p:
p.starmap(self.process_data, list_of_args)
def copy_gt_data(self, case_files: List[List[Path]]):
"""
Copy ground truth to output directory
"""
output_dir_gt = self.output_dir / "labelsTr"
if output_dir_gt.is_dir():
shutil.rmtree(output_dir_gt)
source_dir_gt = case_files[0][-1].parent
shutil.copytree(source_dir_gt, output_dir_gt)
def process_data(self, case: List[Path], case_id: str, overwrite_existing: bool = False):
"""
Extract the nonzero region of a single case, stack data and segmentation into
one array (segmentation in the last channel) and save it as npz (key `data`)
Additional properties per case are saved inside a pkl file
Args:
case (List[Path]): list of paths to data and label (label is always at the last position
and data is sorted after modalities)
case_id (str): case identifier
overwrite_existing (bool): overwrite existing data
"""
try:
logger.info(f"Processing case {case_id}")
npz_exists = (self.output_dir / "imagesTr" / f"{case_id}.npz").is_file()
pkl_exists = (self.output_dir / "imagesTr" / f"{case_id}.pkl").is_file()
if not (npz_exists and pkl_exists) or overwrite_existing:
data, seg, properties = self.load_crop_from_list_of_files(case[:-1], case[-1])
all_data = np.vstack((data, seg))
np.savez_compressed(self.output_dir / "imagesTr" / f"{case_id}.npz", data=all_data)
with open(self.output_dir / "imagesTr" / f"{case_id}.pkl", 'wb') as f:
pickle.dump(properties, f)
else:
logger.warning(f"Case {case_id} already exists and overwrite is deactivated")
except Exception as e:
logger.info(f"exception in: {case_id}: {e}")
raise e
@staticmethod
def load_crop_from_list_of_files(data_files: List[Path], seg_file: Path = None):
"""
Load and crop from a list of files
Args:
data_files (List[Path]): paths to data files
seg_file (Path): path to segmentation
Returns:
np.ndarray: cropped data
np.ndarray: cropped segmentation (filled with -1 where no foreground exists)
dict: additional properties
`original_size_of_raw_data`: original shape of data (correctly reordered)
`original_spacing`: original spacing (correctly reordered)
`list_of_data_files`: paths of data files
`seg_file`: path to label file
`itk_origin`: origin in world coordinates
`itk_spacing`: spacing in world coordinates
`itk_direction`: direction in world coordinates
`crop_bbox`: List[Tuple[int]] cropped bounding box
`classes`: present classes in segmentation
`size_after_cropping`: size after cropping
"""
data, seg, properties = load_case_from_list(data_files, seg_file)
return ImageCropper.crop(data, properties, seg)
@staticmethod
def crop(data: np.ndarray, properties: dict, seg: np.ndarray = None):
"""
Crop data and segmentation to non zero region
Args:
data (np.ndarray): data to crop [C, X, Y, Z]
properties (dict): additional properties
seg (np.ndarray): segmentation [1, X, Y, Z]
Returns:
data (np.ndarray): data to crop [C, X, Y, Z]
seg (np.ndarray): segmentation [1, X, Y, Z]
properties (dict): newly added properties
`crop_bbox`: List[Tuple[int]] cropped bounding box
`classes`: present classes in segmentation
`size_after_cropping`: size after cropping
"""
shape_before = data.shape
data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
shape_after = data.shape
logger.info(f"Shape before crop {shape_before}; after crop {shape_after}; "
f"spacing {np.array(properties['original_spacing'])}")
properties["crop_bbox"] = bbox
properties['classes'] = np.unique(seg)
seg[seg < -1] = 0
properties["size_after_cropping"] = data[0].shape
return data, seg, properties
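# Usage sketch (not part of the original file): paths, file names and the number
# of processes are placeholders; each case is given as [modality_0, ..., label].
def _demo_image_cropper():
    case_files = [
        [Path("raw/imagesTr/case_000_0000.nii.gz"), Path("raw/labelsTr/case_000.nii.gz")],
        [Path("raw/imagesTr/case_001_0000.nii.gz"), Path("raw/labelsTr/case_001.nii.gz")],
    ]
    cropper = ImageCropper(num_processes=4, output_dir=Path("./cropped"))
    # writes case_xxx.npz (key `data`, segmentation in the last channel) and
    # case_xxx.pkl into ./cropped/imagesTr and copies the labels to ./cropped/labelsTr
    cropper.run_cropping(case_files, overwrite_existing=False)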
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from os import PathLike
from loguru import logger
from abc import ABC, abstractmethod
from multiprocessing import Pool
from pathlib import Path
from typing import Dict, Sequence, List, Tuple, TypeVar, Union
from itertools import repeat
from nndet.io.transforms.instances import instances_to_boxes_np
from nndet.io.paths import get_case_ids_from_dir, get_case_id_from_path
from nndet.io.load import load_case_cropped, save_pickle
from nndet.preprocessing.resampling import resample_patient
from nndet.io.crop import ImageCropper
class AbstractPreprocessor(ABC):
DATA_ID = "abstractdata"
def __init__(self, **kwargs):
"""
Interface for preprocessor
"""
for key, item in kwargs.items():
setattr(self, key, item)
@abstractmethod
def run(self,
target_spacings: Sequence[Sequence[float]],
identifiers: Sequence[str],
cropped_data_dir: Path,
preprocessed_output_dir: Path,
num_processes: int,
force_separate_z=None,
):
"""
Run preprocessing
Args:
target_spacings: target spacing for each resolution stage
identifiers: identifier strings used to name the output directory of each stage
cropped_data_dir: source directory
preprocessed_output_dir: target directory
num_processes: number of processes used for preprocessing
force_separate_z: force independent resampling of z direction
"""
raise NotImplementedError
@abstractmethod
def run_test(self,
data_files,
target_spacing,
target_dir: PathLike,
) -> None:
"""
Preprocess and save test data
Args:
data_files: path to data files
target_spacing: spacing to resample
target_dir: directory to save data to
"""
raise NotImplementedError
@abstractmethod
def preprocess_test_case(self,
data_files,
target_spacing,
seg_file=None,
force_separate_z=None,
) -> Tuple[np.ndarray, np.ndarray, dict]:
"""
Preprocess a test file
Args:
data_files: path to data files
target_spacing: spacing to resample
seg_file: optional segmentation file
force_separate_z: separate resampling in z direction
Returns:
np.ndarray: preprocessed data [C, dims]
np.ndarray: preprocessed segmentation [1, dims]
dict: updated properties
"""
raise NotImplementedError
class GenericPreprocessor(AbstractPreprocessor):
DATA_ID = "Generic"
def __init__(self,
norm_scheme_per_modality: Dict[int, str],
use_mask_for_norm: Dict[int, bool],
transpose_forward: Sequence[int],
intensity_properties: Dict[int, Dict] = None,
resample_anisotropy_threshold: float = 3.,
):
"""
Preprocess data
Args:
norm_scheme_per_modality: integer index represents the modality and the string
selects the normalization scheme, one of `CT`, `CT2`, `CT3` or `raw`.
Other modalities are normalized to zero mean and unit std.
use_mask_for_norm: only foreground values should be used for normalization
(defined for each modality)
transpose_forward: transpose input data
intensity_properties: Intensity properties of foreground over the dataset.
Evaluated statistics: `median`; `mean`; `std`; `min`; `max`;
`percentile_99_5`; `percentile_00_5`
`local_props`: contains a dict (with case ids) where statistics
were computed per case
Overwrites:
:self:`data_id`: unique identifier of GenericPreprocessor
"""
self.resample_anisotropy_threshold = resample_anisotropy_threshold
self.intensity_properties = intensity_properties
self.transpose_forward = list(transpose_forward)
self.use_mask_for_norm = use_mask_for_norm
self.norm_scheme_per_modality = norm_scheme_per_modality
self.norm_schemes = {
"CT": self.normalize_ct,
"CT2": self.normalize_ct2,
"CT3": self.normalize_ct,
"raw": self.no_norm,
}
def run(self,
target_spacings: Sequence[Sequence[float]],
identifiers: Sequence[str],
cropped_data_dir: Path,
preprocessed_output_dir: Path,
num_processes: Union[int, Sequence[int]],
overwrite: bool = False,
):
"""
Run preprocessing
Args:
target_spacings: target spacing for each resolution stage
identifiers: identifier strings used to name the output directory of each stage
cropped_data_dir: source directory
preprocessed_output_dir: target directory
num_processes: number of processes used for preprocessing
overwrite: overwrite existing data
"""
case_ids, num_processes = self.initialize_run(
target_spacings=target_spacings,
cropped_data_dir=cropped_data_dir,
preprocessed_output_dir=preprocessed_output_dir,
num_processes=num_processes,
)
for identifier, spacing, nump in zip(identifiers, target_spacings, num_processes):
logger.info(f"+++ Preprocessing {identifier} +++")
output_dir_stage = preprocessed_output_dir / identifier / "imagesTr"
output_dir_stage.mkdir(parents=True, exist_ok=True)
if not overwrite:
case_ids_npz_present = get_case_ids_from_dir(
output_dir_stage, remove_modality=False, pattern="*.npz")
case_ids_pkl_present = get_case_ids_from_dir(
output_dir_stage, remove_modality=False, pattern="*.pkl")
case_ids_present = list(set.intersection(set(case_ids_npz_present), set(case_ids_pkl_present)))
logger.info(f"Skipping case ids which are already present {case_ids_present}")
_case_ids = list(filter(lambda x: x not in case_ids_present, case_ids))
else:
_case_ids = case_ids
logger.info(f"Running preprocessing on {_case_ids}")
if nump == 0:
for _cid in _case_ids:
self.run_process(spacing, _cid, output_dir_stage, cropped_data_dir)
else:
with Pool(processes=nump) as p:
p.starmap(self.run_process,
zip(repeat(spacing),
_case_ids,
repeat(output_dir_stage),
repeat(cropped_data_dir),
))
def initialize_run(self,
target_spacings: Sequence[Sequence[float]],
cropped_data_dir: Path,
preprocessed_output_dir: Path,
num_processes: int,
) -> Tuple[List[str], List[int]]:
"""
Prepare preprocessing run
Args:
target_spacings: target spacings
cropped_data_dir: source dir
preprocessed_output_dir: target dir
num_processes: number of processes
Returns:
List[str]: case ids from source dir
List[int]: number of processes for each stage
"""
logger.info("Initializing preprocessing")
logger.info(f"Folder with cropped data: {cropped_data_dir}")
logger.info(f"Folder for preprocessed data:{preprocessed_output_dir}")
for key, modality in self.norm_scheme_per_modality.items():
if modality in self.norm_schemes:
logger.info(f"Found normalization scheme for {modality}")
else:
logger.info(f"No normalization scheme for {modality} using zero mean unit std.")
preprocessed_output_dir.mkdir(parents=True, exist_ok=True)
num_stages = len(target_spacings)
if not isinstance(num_processes, Sequence):
num_processes = [num_processes] * num_stages
assert len(num_processes) == num_stages
case_ids = get_case_ids_from_dir(
cropped_data_dir, pattern="*.npz", remove_modality=False)
return case_ids, num_processes
def run_process(self,
target_spacing: Sequence[float],
case_id: str,
output_dir_stage: Path,
cropped_data_dir: Path,
) -> None:
"""
Process a single case
Result is saved into :param:`output_dir_stage`
Args:
target_spacing: target spacing for processed case
case_id: case identifier
output_dir_stage: path to output directory
cropped_data_dir: path to source directory
"""
data, seg, properties = load_case_cropped(cropped_data_dir, case_id)
seg = seg[None]
data, seg, properties = self.apply_process(
data, target_spacing, properties, seg)
properties["use_nonzero_mask_for_norm"] = self.use_mask_for_norm
data = data.astype(np.float32)
seg = seg.astype(np.int32)
candidates = self.compute_candidates(
data=data,
seg=seg,
properties=properties,
)
logger.info(f"Saving: {case_id} into {output_dir_stage}.")
np.savez_compressed(str(output_dir_stage / f"{case_id}.npz"),
data=data,
seg=seg,
)
save_pickle(candidates, output_dir_stage / f"{case_id}_boxes.pkl")
save_pickle(properties, output_dir_stage / f"{case_id}.pkl")
def apply_process(self,
data: np.ndarray,
target_spacing: Sequence[float],
properties: dict,
seg: np.ndarray = None,
) -> Tuple[np.ndarray, np.ndarray, dict]:
"""
Applies all preprocessing steps to data and segmentation
Args:
data: input data
target_spacing: target spacing (not! transposed)
properties: dict with properties for preprocessing
seg: input segmentation
Returns:
np.ndarray: preprocessed data
np.ndarray: preprocessed segmentation
dict: updated properties
"""
data, seg, original_spacing, target_spacing, before = self.transpose(
data, seg, properties["original_spacing"], target_spacing)
data, seg, after = self.resample(
data, seg, original_spacing, target_spacing)
# logger.info(f"\nBefore: {before} \nAfter: {after}\n")
if seg is not None:
seg[seg < -1] = 0
properties["size_after_resampling"] = data[0].shape
properties["spacing_after_resampling"] = after["spacing"]
data = self.normalize(data, seg)
return data, seg, properties
def transpose(self,
data: np.ndarray,
seg: np.ndarray,
original_spacing: Sequence[float],
target_spacing: Sequence[float]) -> Tuple[
np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
"""
Transpose data, segmentation and spacings
Args:
data: input data
seg: input segmentation
original_spacing: original spacing
target_spacing: target spacing
Returns:
np.ndarray: transposed data
np.ndarray: transposed segmentation
np.ndarray: transposed original spacing
np.ndarray: transposed target spacing
dict: values for debugging
"""
data = data.transpose((0, *[i + 1 for i in self.transpose_forward]))
seg = seg.transpose((0, *[i + 1 for i in self.transpose_forward]))
_original_spacing = np.array(original_spacing)[self.transpose_forward]
_target_spacing = np.array(target_spacing)[self.transpose_forward]
before = {
"spacing": original_spacing,
"transpose": self.transpose_forward,
"spacing_transposed": _original_spacing,
"shape (transposed": data.shape,
}
return data, seg, _original_spacing, _target_spacing, before
def resample(self,
data: np.ndarray,
seg: np.ndarray,
original_spacing: Sequence[float],
target_spacing: Sequence[float],
) -> Tuple[np.ndarray, np.ndarray, dict]:
"""
Resample data and segmentation to new spacing
Args:
data: input data
seg: input segmentation
original_spacing: original spacing
target_spacing: target spacing
Returns:
np.ndarray: resampled data
np.ndarray: resampled segmentation
dict: properties after resampling
`spacing`: spacing after resampling
`shape (resampled)`: shape after resampling
"""
original_spacing = np.array(original_spacing)
target_spacing = np.array(target_spacing)
data[np.isnan(data)] = 0
data, seg = resample_patient(data,
seg,
original_spacing,
target_spacing,
order_data=3,
order_seg=0,
force_separate_z=False,
order_z_data=9999,
order_z_seg=9999,
separate_z_anisotropy_threshold=
self.resample_anisotropy_threshold,
)
after = {
"spacing": target_spacing,
"shape (resampled)": data.shape,
}
return data, seg, after
def normalize(self, data: np.ndarray, seg: np.ndarray) -> np.ndarray:
"""
Normalize data with correct scheme
Args:
data: input data
seg: input data
Returns:
np.ndarray: normalized data
"""
assert len(self.norm_scheme_per_modality) == len(data), \
f"norm_scheme_per_modality must have as many entries as data has modalities"
assert len(self.use_mask_for_norm) == len(data), \
f"use_mask_for_norm must have as many entries as data has modalities"
for c in range(len(data)):
scheme = self.norm_scheme_per_modality[c]
scheme_fn = self.norm_schemes.get(scheme, self.normalize_other)
data = scheme_fn(data, seg, c, self.use_mask_for_norm)
return data
def normalize_ct(self,
data: np.ndarray,
seg: np.ndarray,
modality: int,
use_nonzero_mask: Dict[int, bool],
) -> np.ndarray:
"""
clip to lb and ub from train data foreground and use foreground mn and sd from training data
(This uses the foreground mean and std!)
Args:
data: data to normalize [C, dims]
seg: segmentation [C, dims]
modality: current modality
use_nonzero_mask: use non zero region for normalization and set all values
outside to zero [C]
Returns:
np.ndarray: normalized data (only the modality channel is changed)
"""
assert self.intensity_properties is not None, \
"ERROR: if there is a CT then we need intensity properties"
mean_intensity = self.intensity_properties[modality]['mean']
std_intensity = self.intensity_properties[modality]['std']
lower_bound = self.intensity_properties[modality]['percentile_00_5']
upper_bound = self.intensity_properties[modality]['percentile_99_5']
data[modality] = np.clip(data[modality], lower_bound, upper_bound)
data[modality] = (data[modality] - mean_intensity) / std_intensity
if use_nonzero_mask[modality]:
data[modality][seg[-1] < 0] = 0
return data
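# Worked example for normalize_ct (illustrative numbers, not from any dataset):
# with foreground statistics mean=40, std=30, percentile_00_5=-20 and
# percentile_99_5=180 (HU), a voxel of 250 HU is clipped to 180 and normalized to
# (180 - 40) / 30 ≈ 4.67, while a voxel of 10 HU stays inside the clip range and
# becomes (10 - 40) / 30 = -1.0. If use_nonzero_mask is set for the modality,
# voxels outside the nonzero region (seg[-1] < 0) are set to 0 afterwards.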
def normalize_ct2(self, data: np.ndarray, seg: np.ndarray, modality: int,
use_nonzero_mask: Dict[int, bool]) -> np.ndarray:
"""
clip to lb and ub from train data foreground, use mn and sd from each case for normalization
(This uses mean and std from whole case!)
Args:
data: data to normalize [C, dims]
seg: segmentation [C, dims]
modality: current modality
use_nonzero_mask: use non zero region for normalization [C]
Returns:
np.ndarray: normalized data (only the modality channel is changed)
"""
assert self.intensity_properties is not None, \
"ERROR: if there is a CT then we need intensity properties"
lower_bound = self.intensity_properties[modality]['percentile_00_5']
upper_bound = self.intensity_properties[modality]['percentile_99_5']
mask = (data[modality] > lower_bound) & (data[modality] < upper_bound)
data[modality] = np.clip(data[modality], lower_bound, upper_bound)
mn = data[modality][mask].mean()
sd = data[modality][mask].std()
data[modality] = (data[modality] - mn) / sd
if use_nonzero_mask[modality]:
data[modality][seg[-1] < 0] = 0
return data
def normalize_ct3(self,
data: np.ndarray,
seg: np.ndarray,
modality: int,
use_nonzero_mask: Dict[int, bool],
) -> np.ndarray:
"""
clip to lb and ub from train data foreground and use foreground mn
and sd from training data (This uses the foreground mean and std!)
Use this if channels are overloaded with spatial information
(in the case of CT)
Args:
data: data to normalize [C, dims]
seg: segmentation [C, dims]
modality: current modality
use_nonzero_mask: use non zero region for normalization and set all values
outside to zero [C]
Returns:
np.ndarray: normalized data (only the modality channel is changed)
"""
assert self.intensity_properties is not None, \
"ERROR: if there is a CT then we need intensity properties"
mean_intensity = np.mean([k["mean"] for k in self.intensity_properties.values()])
# the intensity values are not independent but we do not have enough information here
std_intensity = np.sqrt(np.sum([k["std"] ** 2 for k in self.intensity_properties.values()]))
lower_bound = np.mean([k["percentile_00_5"] for k in self.intensity_properties.values()])
upper_bound = np.mean([k["percentile_99_5"] for k in self.intensity_properties.values()])
data[modality] = np.clip(data[modality], lower_bound, upper_bound)
data[modality] = (data[modality] - mean_intensity) / std_intensity
if use_nonzero_mask[modality]:
data[modality][seg[-1] < 0] = 0
return data
def normalize_other(self, data: np.ndarray, seg: np.ndarray, modality: int,
use_nonzero_mask: Dict[int, bool]) -> np.ndarray:
"""
Zero mean and unit std
Args:
data: data to normalize [C, dims]
seg: segmentation [C, dims]
modality: current modality
use_nonzero_mask: use non zero region for normalization [C]
Returns:
np.ndarray: normalized data (only the modality channel is changed)
"""
if use_nonzero_mask[modality]:
mask = seg[-1] >= 0
else:
mask = np.ones(seg.shape[1:], dtype=bool)
data[modality][mask] = (data[modality][mask] - data[modality][mask].mean()) / \
(data[modality][mask].std() + 1e-8)
data[modality][mask == 0] = 0
return data
def no_norm(self, data: np.ndarray, seg: np.ndarray, modality: int,
use_nonzero_mask: Dict[int, bool]) -> np.ndarray:
"""
No normalization, only masking
Args:
data: data to normalize [C, dims]
seg: segmentation [C, dims]
modality: current modality
use_nonzero_mask: use non zero region for normalization [C]
Returns:
np.ndarray: masked data (only modality channel was changed)
"""
if use_nonzero_mask[modality]:
mask = seg[-1] >= 0
else:
mask = np.ones(seg.shape[1:], dtype=bool)
data[modality][mask == 0] = 0
return data
@staticmethod
def compute_candidates(
data: np.ndarray,
seg: np.ndarray,
properties: dict,
) -> dict:
"""
Precompute candidate sampling positions for training
This method computes the bounding boxes of each present
instance which can be used to oversample foreground effectively.
Args:
data: data after resampling
seg: instance segmentation after resampling
properties: case properties; the `instances` entry maps instance ids
(as str) to their class label
Returns:
dict: precomputed sampling candidates
`boxes`: bounding box per instance
`instances`: instance ids present in the case
`labels`: class label per instance
"""
dim = data.ndim - 1
boxes = instances_to_boxes_np(seg[0], dim=dim)[0]
instances = np.unique(seg)
instances = instances[instances > 0].astype(np.int32) # [N]
instances = instances.tolist()
instances_props = properties["instances"]
labels = [int(instances_props[str(i)]) for i in instances]
assert (len(boxes) == len(instances)) or ((boxes.size == 0) and (len(instances) == 0))
assert len(labels) == len(instances)
return {
"boxes": boxes,
"instances": instances,
"labels": labels,
}
def run_test(self,
data_files,
target_spacing,
target_dir: PathLike,
) -> None:
"""
Preprocess and save test data
Args:
data_files: path to data files
target_spacing: spacing to resample
target_dir: directory to save data to
"""
target_dir = Path(target_dir)
data, seg, properties = self.preprocess_test_case(
data_files=data_files,
target_spacing=target_spacing,
)
case_id = get_case_id_from_path(str(data_files[0]), remove_modality=True)
np.savez_compressed(str(target_dir / f"{case_id}.npz"), data=data)
save_pickle(properties, target_dir / f"{case_id}")
def preprocess_test_case(self,
data_files,
target_spacing,
seg_file=None,
) -> Tuple[np.ndarray, np.ndarray, dict]:
"""
Preprocess a test file
Args:
data_files: path to data files
target_spacing: spacing to resample
seg_file: optional segmentation file
Returns:
np.ndarray: preprocessed data
np.ndarray: preprocessed segmentation
dict: updated properties
"""
data, seg, properties = ImageCropper.load_crop_from_list_of_files(
data_files, seg_file)
data, seg, properties = self.apply_process(
data=data,
target_spacing=target_spacing,
properties=properties,
seg=seg,
)
return data.astype(np.float32), seg.astype(np.int32), properties
PreprocessorType = TypeVar('PreprocessorType', bound=AbstractPreprocessor)
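# Usage sketch (not part of the original file): directories, spacings and the
# intensity statistics below are placeholders for a single-modality CT setup.
def _demo_generic_preprocessor():
    preprocessor = GenericPreprocessor(
        norm_scheme_per_modality={0: "CT"},
        use_mask_for_norm={0: False},
        transpose_forward=[0, 1, 2],
        intensity_properties={0: {"mean": 40., "std": 30., "median": 38.,
                                  "min": -1000., "max": 3000.,
                                  "percentile_00_5": -20., "percentile_99_5": 180.}},
    )
    preprocessor.run(
        target_spacings=[[3.0, 1.5, 1.5]],    # one resolution stage
        identifiers=["stage0"],               # names the output subdirectory
        cropped_data_dir=Path("./cropped/imagesTr"),
        preprocessed_output_dir=Path("./preprocessed"),
        num_processes=4,
    )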
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import nnunet.preprocessing.preprocessing as nn_preprocessing
def resize_segmentation(segmentation, new_shape, order=3, cval=0):
"""
Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
hot encoding which is resized and transformed back to a segmentation map.
This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
"""
return nn_preprocessing.resize_segmentation(
segmentation=segmentation, new_shape=new_shape, order=order, cval=cval)
def get_do_separate_z(spacing, anisotropy_threshold: float = 3):
return nn_preprocessing.get_do_separate_z(spacing=spacing, anisotropy_threshold=anisotropy_threshold)
def get_lowres_axis(new_spacing):
return nn_preprocessing.get_lowres_axis(new_spacing=new_spacing)
def resample_patient(data,
seg,
original_spacing,
target_spacing,
order_data=3,
order_seg=0,
force_separate_z=False,
cval_data=0,
cval_seg=-1,
order_z_data=0,
order_z_seg=0,
separate_z_anisotropy_threshold: float = 3,
):
return nn_preprocessing.resample_patient(data=data, seg=seg, original_spacing=original_spacing,
target_spacing=target_spacing, order_data=order_data,
order_seg=order_seg, force_separate_z=force_separate_z,
cval_data=cval_data, cval_seg=cval_seg, order_z_data=order_z_data,
order_z_seg=order_z_seg,
separate_z_anisotropy_threshold=separate_z_anisotropy_threshold)
def resample_data_or_seg(data, new_shape, is_seg, axis=None, order=3,
do_separate_z=False, cval=0, order_z=0) -> np.ndarray:
"""
Resample data or segmentation
Args:
data: array to resample [C, dims]
new_shape: define new dims (without channels)
is_seg: changes the resampling strategy
axis: anisotropic axis, different resampling order used here
order: order of resampling along the isotropic axis
do_separate_z: Different resampling along z dimensions
cval: constant fill value used by the underlying resize operation
order_z: if separate z resampling is done then this is the order for resampling in z
Returns:
np.ndarray: resampled array
"""
return nn_preprocessing.resample_data_or_seg(
data=data, new_shape=new_shape, is_seg=is_seg, axis=axis,
order=order, do_separate_z=do_separate_z, cval=cval, order_z=order_z)
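# Sketch (not part of the original file): resampling a synthetic, anisotropic case
# from 3 mm to 1.5 mm slice spacing; array contents are arbitrary.
def _demo_resample_patient():
    data = np.random.rand(1, 40, 128, 128).astype(np.float32)               # [C, Z, Y, X]
    seg = np.random.randint(0, 2, size=(1, 40, 128, 128)).astype(np.int32)  # [1, Z, Y, X]
    data_res, seg_res = resample_patient(
        data, seg,
        original_spacing=np.array([3.0, 1.0, 1.0]),
        target_spacing=np.array([1.5, 1.0, 1.0]),
        order_data=3, order_seg=0,
    )
    # the spatial shape roughly doubles along the first axis: (1, 80, 128, 128)
    return data_res, seg_res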
from typing import Mapping, Type
from nndet.utils.registry import Registry
from nndet.ptmodule.base_module import LightningBaseModule
MODULE_REGISTRY: Mapping[str, Type[LightningBaseModule]] = Registry()
# register modules
from nndet.ptmodule.retinaunet import *
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations
import os
from time import time
from typing import Any, Callable, Dict, Optional, Sequence, Hashable, Type, TypeVar
import torch
import pytorch_lightning as pl
from pytorch_lightning.core.memory import ModelSummary
from loguru import logger
from nndet.io.load import save_txt
from nndet.inference.predictor import Predictor
class LightningBaseModule(pl.LightningModule):
def __init__(self,
model_cfg: dict,
trainer_cfg: dict,
plan: dict,
**kwargs
):
"""
Provides a base module which is used inside of nnDetection.
All lightning modules of nnDetection should be derived from this!
Args:
model_cfg: model configuration. Check :method:`from_config_plan`
for more information
trainer_cfg: trainer information
plan: contains parameters which were derived from the planning
stage
"""
super().__init__()
self.model_cfg = model_cfg
self.trainer_cfg = trainer_cfg
self.plan = plan
self.model = self.from_config_plan(
model_cfg=self.model_cfg,
plan_arch=self.plan["architecture"],
plan_anchors=self.plan["anchors"],
)
self.example_input_array_shape = (
1, plan["architecture"]["in_channels"], *plan["patch_size"],
)
self.epoch_start_tic = 0
self.epoch_end_toc = 0
@property
def max_epochs(self):
"""
Number of epochs to train
"""
return self.trainer_cfg["max_num_epochs"]
def on_epoch_start(self) -> None:
"""
Save time
"""
self.epoch_start_tic = time()
return super().on_epoch_start()
def validation_epoch_end(self, validation_step_outputs):
"""
Print time of epoch
(needed for cluster where progress bar is deactivated)
"""
self.epoch_end_toc = time()
logger.info(f"This epoch took {int(self.epoch_end_toc - self.epoch_start_tic)} s")
return super().validation_epoch_end(validation_step_outputs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Used to generate summary
Do not(!) use this for inference. This will only forward
the input through the network which does not include
detection specific postprocessing!
"""
return self.model(x)
@property
def example_input_array(self):
"""
Create example input
"""
return torch.zeros(*self.example_input_array_shape)
def summarize(self, mode: Optional[str]) -> Optional[ModelSummary]:
"""
Save model summary as txt
"""
summary = super().summarize(mode=mode)
save_txt(summary, "./network")
return summary
def inference_step(self, batch: Any, **kwargs) -> Dict[str, Any]:
"""
Prediction method used by nnDetection predictor class
"""
return self.model.inference_step(batch, **kwargs)
@classmethod
def from_config_plan(cls,
model_cfg: dict,
plan_arch: dict,
plan_anchors: dict,
log_num_anchors: str = None,
**kwargs,
):
"""
Used to generate the model
"""
raise NotImplementedError
@staticmethod
def get_ensembler_cls(key: Hashable, dim: int) -> Callable:
"""
Get ensembler classes to combine multiple predictions
Needs to be overwritten in subclasses!
"""
raise NotImplementedError
@classmethod
def get_predictor(cls,
plan: Dict,
models: Sequence[LightningBaseModule],
num_tta_transforms: int = None,
**kwargs
) -> Type[Predictor]:
"""
Get predictor
Needs to be overwritten in subclasses!
"""
raise NotImplementedError
def sweep(self,
cfg: dict,
save_dir: os.PathLike,
train_data_dir: os.PathLike,
case_ids: Sequence[str],
run_prediction: bool = True,
) -> Dict[str, Any]:
"""
Sweep parameters to find the best predictions
Needs to be overwritten in subclasses!
Args:
cfg: config used for training
save_dir: save dir used for training
train_data_dir: directory where preprocessed training/validation
data is located
case_ids: case identifiers to prepare and predict
run_prediction: predict cases
**kwargs: keyword arguments passed to predict function
"""
raise NotImplementedError
class LightningBaseModuleSWA(LightningBaseModule):
@property
def max_epochs(self):
"""
Number of epochs to train
"""
return self.trainer_cfg["max_num_epochs"] + self.trainer_cfg["swa_epochs"]
def configure_callbacks(self):
from nndet.training.swa import SWACycleLinear
callbacks = []
callbacks.append(
SWACycleLinear(
swa_epoch_start=self.trainer_cfg["max_num_epochs"],
cycle_initial_lr=self.trainer_cfg["initial_lr"] / 10.,
cycle_final_lr=self.trainer_cfg["initial_lr"] / 1000.,
num_iterations_per_epoch=self.trainer_cfg["num_train_batches_per_epoch"],
)
)
return callbacks
LightningBaseModuleType = TypeVar('LightningBaseModuleType', bound=LightningBaseModule)
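# Sketch (not part of the original file): the hooks a concrete nnDetection module
# fills in on top of LightningBaseModule; `build_my_network` and the ensembler
# lookup below are hypothetical placeholders, not part of the library.
#
# class MyDetectionModule(LightningBaseModule):
#     @classmethod
#     def from_config_plan(cls, model_cfg, plan_arch, plan_anchors, **kwargs):
#         return build_my_network(plan_arch, plan_anchors, **model_cfg)
#
#     @staticmethod
#     def get_ensembler_cls(key, dim):
#         return my_ensembler_lookup[dim][key]   # keys like "boxes" / "seg"
#
# get_predictor and sweep follow the same pattern and are typically implemented
# once per model family (see RetinaUNetModule further below).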
from nndet.ptmodule.retinaunet.base import RetinaUNetModule
from nndet.ptmodule.retinaunet.v001 import RetinaUNetV001
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations
import os
import copy
from collections import defaultdict
from pathlib import Path
from functools import partial
from typing import Callable, Hashable, Sequence, Dict, Any, Type
import torch
import numpy as np
from loguru import logger
from torchvision.models.detection.rpn import AnchorGenerator
from nndet.utils.tensor import to_numpy
from nndet.evaluator.det import BoxEvaluator
from nndet.evaluator.seg import SegmentationEvaluator
from nndet.core.retina import BaseRetinaNet
from nndet.core.boxes.matcher import IoUMatcher
from nndet.core.boxes.sampler import HardNegativeSamplerBatched
from nndet.core.boxes.coder import CoderType, BoxCoderND
from nndet.core.boxes.anchors import get_anchor_generator
from nndet.core.boxes.ops import box_iou
from nndet.core.boxes.anchors import AnchorGeneratorType
from nndet.ptmodule.base_module import LightningBaseModuleSWA, LightningBaseModule
from nndet.arch.conv import Generator, ConvInstanceRelu, ConvGroupRelu
from nndet.arch.blocks.basic import StackedConvBlock2
from nndet.arch.encoder.abstract import EncoderType
from nndet.arch.encoder.modular import Encoder
from nndet.arch.decoder.base import DecoderType, BaseUFPN, UFPNModular
from nndet.arch.heads.classifier import ClassifierType, CEClassifier
from nndet.arch.heads.regressor import RegressorType, L1Regressor
from nndet.arch.heads.comb import HeadType, DetectionHeadHNM
from nndet.arch.heads.segmenter import SegmenterType, DiCESegmenter
from nndet.training.optimizer import get_params_no_wd_on_norm
from nndet.training.learning_rate import LinearWarmupPolyLR
from nndet.inference.predictor import Predictor
from nndet.inference.sweeper import BoxSweeper
from nndet.inference.transforms import get_tta_transforms, Inference2D
from nndet.inference.loading import get_loader_fn
from nndet.inference.helper import predict_dir
from nndet.inference.ensembler.segmentation import SegmentationEnsembler
from nndet.inference.ensembler.detection import BoxEnsemblerSelective
from nndet.io.transforms import (
Compose,
Instances2Boxes,
Instances2Segmentation,
FindInstances,
)
class RetinaUNetModule(LightningBaseModuleSWA):
base_conv_cls = ConvInstanceRelu
head_conv_cls = ConvGroupRelu
block = StackedConvBlock2
encoder_cls = Encoder
decoder_cls = UFPNModular
matcher_cls = IoUMatcher
head_cls = DetectionHeadHNM
head_classifier_cls = CEClassifier
head_regressor_cls = L1Regressor
head_sampler_cls = HardNegativeSamplerBatched
segmenter_cls = DiCESegmenter
def __init__(self,
model_cfg: dict,
trainer_cfg: dict,
plan: dict,
**kwargs
):
"""
RetinaUNet Lightning Module Skeleton
Args:
model_cfg: model configuration. Check :method:`from_config_plan`
for more information
trainer_cfg: trainer information
plan: contains parameters which were derived from the planning
stage
"""
super().__init__(
model_cfg=model_cfg,
trainer_cfg=trainer_cfg,
plan=plan,
)
_classes = [f"class{c}" for c in range(plan["architecture"]["classifier_classes"])]
self.box_evaluator = BoxEvaluator.create(
classes=_classes,
fast=True,
save_dir=None,
)
self.seg_evaluator = SegmentationEvaluator.create()
self.pre_trafo = Compose(
FindInstances(
instance_key="target",
save_key="present_instances",
),
Instances2Boxes(
instance_key="target",
map_key="instance_mapping",
box_key="boxes",
class_key="classes",
present_instances="present_instances",
),
Instances2Segmentation(
instance_key="target",
map_key="instance_mapping",
present_instances="present_instances",
)
)
self.eval_score_key = "mAP_IoU_0.10_0.50_0.05_MaxDet_100"
def training_step(self, batch, batch_idx):
"""
Computes a single training step
See :class:`BaseRetinaNet` for more information
"""
with torch.no_grad():
batch = self.pre_trafo(**batch)
losses, _ = self.model.train_step(
images=batch["data"],
targets={
"target_boxes": batch["boxes"],
"target_classes": batch["classes"],
"target_seg": batch['target'][:, 0] # Remove channel dimension
},
evaluation=False,
batch_num=batch_idx,
)
loss = sum(losses.values())
return {"loss": loss, **{key: l.detach().item() for key, l in losses.items()}}
def validation_step(self, batch, batch_idx):
"""
Computes a single validation step (same as train step but with
additional prediction processing)
See :class:`BaseRetinaNet` for more information
"""
with torch.no_grad():
batch = self.pre_trafo(**batch)
targets = {
"target_boxes": batch["boxes"],
"target_classes": batch["classes"],
"target_seg": batch['target'][:, 0] # Remove channel dimension
}
losses, prediction = self.model.train_step(
images=batch["data"],
targets=targets,
evaluation=True,
batch_num=batch_idx,
)
loss = sum(losses.values())
self.evaluation_step(prediction=prediction, targets=targets)
return {"loss": loss.detach().item(),
**{key: l.detach().item() for key, l in losses.items()}}
def evaluation_step(
self,
prediction: dict,
targets: dict,
):
"""
Perform an evaluation step to add predictions and gt to
caching mechanism which is evaluated at the end of the epoch
Args:
prediction: predictions obtained from model
'pred_boxes': List[Tensor]: predicted bounding boxes for
each image List[[R, dim * 2]]
'pred_scores': List[Tensor]: predicted probability for
the class List[[R]]
'pred_labels': List[Tensor]: predicted class List[[R]]
'pred_seg': Tensor: predicted segmentation [N, dims]
targets: ground truth
`target_boxes` (List[Tensor]): ground truth bounding boxes
(x1, y1, x2, y2, (z1, z2))[X, dim * 2], X= number of ground
truth boxes in image
`target_classes` (List[Tensor]): ground truth class per box
(classes start from 0) [X], X= number of ground truth
boxes in image
`target_seg` (Tensor): segmentation ground truth (if seg was
found in input dict)
"""
pred_boxes = to_numpy(prediction["pred_boxes"])
pred_classes = to_numpy(prediction["pred_labels"])
pred_scores = to_numpy(prediction["pred_scores"])
gt_boxes = to_numpy(targets["target_boxes"])
gt_classes = to_numpy(targets["target_classes"])
gt_ignore = None
self.box_evaluator.run_online_evaluation(
pred_boxes=pred_boxes,
pred_classes=pred_classes,
pred_scores=pred_scores,
gt_boxes=gt_boxes,
gt_classes=gt_classes,
gt_ignore=gt_ignore,
)
pred_seg = to_numpy(prediction["pred_seg"])
gt_seg = to_numpy(targets["target_seg"])
self.seg_evaluator.run_online_evaluation(
seg_probs=pred_seg,
target=gt_seg,
)
def training_epoch_end(self, training_step_outputs):
"""
Log train loss to loguru logger
"""
# process and log losses
vals = defaultdict(list)
for _val in training_step_outputs:
for _k, _v in _val.items():
if _k == "loss":
vals[_k].append(_v.detach().item())
else:
vals[_k].append(_v)
for _key, _vals in vals.items():
mean_val = np.mean(_vals)
if _key == "loss":
logger.info(f"Train loss reached: {mean_val:0.5f}")
self.log(f"train_{_key}", mean_val, sync_dist=True)
return super().training_epoch_end(training_step_outputs)
def validation_epoch_end(self, validation_step_outputs):
"""
Log val loss to loguru logger
"""
# process and log losses
vals = defaultdict(list)
for _val in validation_step_outputs:
for _k, _v in _val.items():
vals[_k].append(_v)
for _key, _vals in vals.items():
mean_val = np.mean(_vals)
if _key == "loss":
logger.info(f"Val loss reached: {mean_val:0.5f}")
self.log(f"val_{_key}", mean_val, sync_dist=True)
# process and log metrics
self.evaluation_end()
return super().validation_epoch_end(validation_step_outputs)
def evaluation_end(self):
"""
Uses the cached values from `evaluation_step` to perform the evaluation
of the epoch
"""
metric_scores, _ = self.box_evaluator.finish_online_evaluation()
self.box_evaluator.reset()
logger.info(f"mAP@0.1:0.5:0.05: {metric_scores['mAP_IoU_0.10_0.50_0.05_MaxDet_100']:0.3f} "
f"AP@0.1: {metric_scores['AP_IoU_0.10_MaxDet_100']:0.3f} "
f"AP@0.5: {metric_scores['AP_IoU_0.50_MaxDet_100']:0.3f}")
seg_scores, _ = self.seg_evaluator.finish_online_evaluation()
self.seg_evaluator.reset()
metric_scores.update(seg_scores)
logger.info(f"Proxy FG Dice: {seg_scores['seg_dice']:0.3f}")
for key, item in metric_scores.items():
self.log(f'{key}', item, on_step=None, on_epoch=True, prog_bar=False, logger=True)
def configure_optimizers(self):
"""
Configure optimizer and scheduler
Base configuration is SGD with LinearWarmup and PolyLR learning rate
schedule
"""
# configure optimizer
logger.info(f"Running: initial_lr {self.trainer_cfg['initial_lr']} "
f"weight_decay {self.trainer_cfg['weight_decay']} "
f"SGD with momentum {self.trainer_cfg['sgd_momentum']} and "
f"nesterov {self.trainer_cfg['sgd_nesterov']}")
wd_groups = get_params_no_wd_on_norm(self, weight_decay=self.trainer_cfg['weight_decay'])
optimizer = torch.optim.SGD(
wd_groups,
self.trainer_cfg["initial_lr"],
weight_decay=self.trainer_cfg["weight_decay"],
momentum=self.trainer_cfg["sgd_momentum"],
nesterov=self.trainer_cfg["sgd_nesterov"],
)
# configure lr scheduler
num_iterations = self.trainer_cfg["max_num_epochs"] * \
self.trainer_cfg["num_train_batches_per_epoch"]
scheduler = LinearWarmupPolyLR(
optimizer=optimizer,
warm_iterations=self.trainer_cfg["warm_iterations"],
warm_lr=self.trainer_cfg["warm_lr"],
poly_gamma=self.trainer_cfg["poly_gamma"],
num_iterations=num_iterations
)
return [optimizer], {'scheduler': scheduler, 'interval': 'step'}
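# Worked example (illustrative numbers, not from any config): with
# max_num_epochs=50 and num_train_batches_per_epoch=2500, the scheduler is
# stepped per batch ('interval': 'step') for 50 * 2500 = 125000 iterations;
# as the name suggests, LinearWarmupPolyLR ramps the learning rate linearly
# from warm_lr to initial_lr during the first warm_iterations steps and then
# decays it polynomially with exponent poly_gamma.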
@classmethod
def from_config_plan(cls,
model_cfg: dict,
plan_arch: dict,
plan_anchors: dict,
log_num_anchors: str = None,
**kwargs,
):
"""
Create Configurable RetinaUNet
Args:
model_cfg: model configurations
See example configs for more info
plan_arch: plan architecture
`dim` (int): number of spatial dimensions
`in_channels` (int): number of input channels
`classifier_classes` (int): number of detection classes
`seg_classes` (int): number of segmentation classes
`start_channels` (int): number of start channels in encoder
`fpn_channels` (int): number of channels to use for FPN
`head_channels` (int): number of channels to use for head
`decoder_levels` (int): decoder levels to use for detection
plan_anchors: parameters for anchors (see
:class:`AnchorGenerator` for more info)
`stride`: stride
`aspect_ratios`: aspect ratios
`sizes`: sizes for 2d anchors
(`zsizes`: additional z sizes for 3d)
log_num_anchors: name of logger to use; if None, no logging
will be performed
**kwargs:
"""
logger.info(f"Architecture overwrites: {model_cfg['plan_arch_overwrites']} "
f"Anchor overwrites: {model_cfg['plan_anchors_overwrites']}")
logger.info(f"Building architecture according to plan of {plan_arch.get('arch_name', 'not_found')}")
plan_arch.update(model_cfg["plan_arch_overwrites"])
plan_anchors.update(model_cfg["plan_anchors_overwrites"])
logger.info(f"Start channels: {plan_arch['start_channels']}; "
f"head channels: {plan_arch['head_channels']}; "
f"fpn channels: {plan_arch['fpn_channels']}")
_plan_anchors = copy.deepcopy(plan_anchors)
coder = BoxCoderND(weights=(1.,) * (plan_arch["dim"] * 2))
s_param = False if ("aspect_ratios" in _plan_anchors) and \
(_plan_anchors["aspect_ratios"] is not None) else True
anchor_generator = get_anchor_generator(
plan_arch["dim"], s_param=s_param)(**_plan_anchors)
encoder = cls._build_encoder(
plan_arch=plan_arch,
model_cfg=model_cfg,
)
decoder = cls._build_decoder(
encoder=encoder,
plan_arch=plan_arch,
model_cfg=model_cfg,
)
matcher = cls.matcher_cls(
similarity_fn=box_iou,
**model_cfg["matcher_kwargs"],
)
classifier = cls._build_head_classifier(
plan_arch=plan_arch,
model_cfg=model_cfg,
anchor_generator=anchor_generator,
)
regressor = cls._build_head_regressor(
plan_arch=plan_arch,
model_cfg=model_cfg,
anchor_generator=anchor_generator,
)
head = cls._build_head(
plan_arch=plan_arch,
model_cfg=model_cfg,
classifier=classifier,
regressor=regressor,
coder=coder
)
segmenter = cls._build_segmenter(
plan_arch=plan_arch,
model_cfg=model_cfg,
decoder=decoder,
)
detections_per_img = plan_arch.get("detections_per_img", 100)
score_thresh = plan_arch.get("score_thresh", 0)
topk_candidates = plan_arch.get("topk_candidates", 10000)
remove_small_boxes = plan_arch.get("remove_small_boxes", 0.01)
nms_thresh = plan_arch.get("nms_thresh", 0.6)
logger.info(f"Model Inference Summary: \n"
f"detections_per_img: {detections_per_img} \n"
f"score_thresh: {score_thresh} \n"
f"topk_candidates: {topk_candidates} \n"
f"remove_small_boxes: {remove_small_boxes} \n"
f"nms_thresh: {nms_thresh}",
)
return BaseRetinaNet(
dim=plan_arch["dim"],
encoder=encoder,
decoder=decoder,
head=head,
anchor_generator=anchor_generator,
matcher=matcher,
num_classes=plan_arch["classifier_classes"],
decoder_levels=plan_arch["decoder_levels"],
segmenter=segmenter,
# model_max_instances_per_batch_element (in mdt per img, per class; here: per img)
detections_per_img=detections_per_img,
score_thresh=score_thresh,
topk_candidates=topk_candidates,
remove_small_boxes=remove_small_boxes,
nms_thresh=nms_thresh,
)
@classmethod
def _build_encoder(
cls,
plan_arch: dict,
model_cfg: dict,
) -> EncoderType:
"""
Build encoder network
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
Returns:
EncoderType: encoder instance
"""
conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: encoder {cls.encoder_cls.__name__}: {model_cfg['encoder_kwargs']} ")
encoder = cls.encoder_cls(
conv=conv,
conv_kernels=plan_arch["conv_kernels"],
strides=plan_arch["strides"],
block_cls=cls.block,
in_channels=plan_arch["in_channels"],
start_channels=plan_arch["start_channels"],
stage_kwargs=None,
max_channels=plan_arch.get("max_channels", 320),
**model_cfg['encoder_kwargs'],
)
return encoder
@classmethod
def _build_decoder(
cls,
plan_arch: dict,
model_cfg: dict,
encoder: EncoderType,
) -> DecoderType:
"""
Build decoder network
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
Returns:
DecoderType: decoder instance
"""
conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: decoder {cls.decoder_cls.__name__}: {model_cfg['decoder_kwargs']}")
decoder = cls.decoder_cls(
conv=conv,
conv_kernels=plan_arch["conv_kernels"],
strides=encoder.get_strides(),
in_channels=encoder.get_channels(),
decoder_levels=plan_arch["decoder_levels"],
fixed_out_channels=plan_arch["fpn_channels"],
**model_cfg['decoder_kwargs'],
)
return decoder
@classmethod
def _build_head_classifier(
cls,
plan_arch: dict,
model_cfg: dict,
anchor_generator: AnchorGeneratorType,
) -> ClassifierType:
"""
Build classification subnetwork for detection head
Args:
anchor_generator: anchor generator instance
plan_arch: architecture settings
model_cfg: additional architecture settings
Returns:
ClassifierType: classification instance
"""
conv = Generator(cls.head_conv_cls, plan_arch["dim"])
name = cls.head_classifier_cls.__name__
kwargs = model_cfg['head_classifier_kwargs']
logger.info(f"Building:: classifier {name}: {kwargs}")
classifier = cls.head_classifier_cls(
conv=conv,
in_channels=plan_arch["fpn_channels"],
internal_channels=plan_arch["head_channels"],
num_classes=plan_arch["classifier_classes"],
anchors_per_pos=anchor_generator.num_anchors_per_location()[0],
num_levels=len(plan_arch["decoder_levels"]),
**kwargs,
)
return classifier
@classmethod
def _build_head_regressor(
cls,
plan_arch: dict,
model_cfg: dict,
anchor_generator: AnchorGeneratorType,
) -> RegressorType:
"""
Build regression subnetwork for detection head
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
anchor_generator: anchor generator instance
Returns:
RegressorType: classification instance
"""
conv = Generator(cls.head_conv_cls, plan_arch["dim"])
name = cls.head_regressor_cls.__name__
kwargs = model_cfg['head_regressor_kwargs']
logger.info(f"Building:: regressor {name}: {kwargs}")
regressor = cls.head_regressor_cls(
conv=conv,
in_channels=plan_arch["fpn_channels"],
internal_channels=plan_arch["head_channels"],
anchors_per_pos=anchor_generator.num_anchors_per_location()[0],
num_levels=len(plan_arch["decoder_levels"]),
**kwargs,
)
return regressor
@classmethod
def _build_head(
cls,
plan_arch: dict,
model_cfg: dict,
classifier: ClassifierType,
regressor: RegressorType,
coder: CoderType,
) -> HeadType:
"""
Build detection head
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
classifier: classifier instance
regressor: regressor instance
coder: coder instance to encode boxes
Returns:
HeadType: instantiated head
"""
head_name = cls.head_cls.__name__
head_kwargs = model_cfg['head_kwargs']
sampler_name = cls.head_sampler_cls.__name__
sampler_kwargs = model_cfg['head_sampler_kwargs']
logger.info(f"Building:: head {head_name}: {head_kwargs} "
f"sampler {sampler_name}: {sampler_kwargs}")
sampler = cls.head_sampler_cls(**sampler_kwargs)
head = cls.head_cls(
classifier=classifier,
regressor=regressor,
coder=coder,
sampler=sampler,
log_num_anchors=None,
**head_kwargs,
)
return head
@classmethod
def _build_segmenter(
cls,
plan_arch: dict,
model_cfg: dict,
decoder: DecoderType,
) -> SegmenterType:
"""
Build segmenter head
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
decoder: decoder instance
Returns:
SegmenterType: segmenter head
"""
if cls.segmenter_cls is not None:
name = cls.segmenter_cls.__name__
kwargs = model_cfg['segmenter_kwargs']
conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: segmenter {name} {kwargs}")
segmenter = cls.segmenter_cls(
conv,
seg_classes=plan_arch["seg_classes"],
in_channels=decoder.get_channels(),
decoder_levels=plan_arch["decoder_levels"],
**kwargs,
)
else:
segmenter = None
return segmenter
@staticmethod
def get_ensembler_cls(key: Hashable, dim: int) -> Callable:
"""
Get ensembler classes to combine multiple predictions
Needs to be overwritten in subclasses!
"""
_lookup = {
2: {
"boxes": None,
"seg": None,
},
3: {
"boxes": BoxEnsemblerSelective,
"seg": SegmentationEnsembler,
}
}
if dim == 2:
raise NotImplementedError
return _lookup[dim][key]
@classmethod
def get_predictor(cls,
plan: Dict,
models: Sequence[RetinaUNetModule],
num_tta_transforms: int = None,
do_seg: bool = False,
**kwargs,
) -> Predictor:
# process plan
crop_size = plan["patch_size"]
batch_size = plan["batch_size"]
inference_plan = plan.get("inference_plan", {})
logger.info(f"Found inference plan: {inference_plan} for prediction")
if num_tta_transforms is None:
num_tta_transforms = 8 if plan["network_dim"] == 3 else 4
# setup
tta_transforms, tta_inverse_transforms = \
get_tta_transforms(num_tta_transforms, True)
logger.info(f"Using {len(tta_transforms)} tta transformations for prediction (one dummy trafo).")
ensembler = {"boxes": partial(
cls.get_ensembler_cls(key="boxes", dim=plan["network_dim"]).from_case,
parameters=inference_plan,
)}
if do_seg:
ensembler["seg"] = partial(
cls.get_ensembler_cls(key="seg", dim=plan["network_dim"]).from_case,
)
predictor = Predictor(
ensembler=ensembler,
models=models,
crop_size=crop_size,
tta_transforms=tta_transforms,
tta_inverse_transforms=tta_inverse_transforms,
batch_size=batch_size,
**kwargs,
)
if plan["network_dim"] == 2:
raise NotImplementedError
predictor.pre_transform = Inference2D(["data"])
return predictor
def sweep(self,
cfg: dict,
save_dir: os.PathLike,
train_data_dir: os.PathLike,
case_ids: Sequence[str],
run_prediction: bool = True,
**kwargs,
) -> Dict[str, Any]:
"""
Sweep detection parameters to find the best predictions
Args:
cfg: config used for training
save_dir: save dir used for training
train_data_dir: directory where preprocessed training/validation
data is located
case_ids: case identifiers to prepare and predict
run_prediction: predict cases
**kwargs: keyword arguments passed to predict function
Returns:
Dict: inference plan
e.g. (exact params depend on ensembler class used for prediction)
`iou_thresh` (float): best IoU threshold
`score_thresh` (float): best score threshold
`no_overlap` (bool): enable/disable class independent NMS (ciNMS)
"""
logger.info(f"Running parameter sweep on {case_ids}")
train_data_dir = Path(train_data_dir)
preprocessed_dir = train_data_dir.parent
processed_eval_labels = preprocessed_dir / "labelsTr"
_save_dir = save_dir / "sweep"
_save_dir.mkdir(parents=True, exist_ok=True)
prediction_dir = save_dir / "sweep_predictions"
prediction_dir.mkdir(parents=True, exist_ok=True)
if run_prediction:
logger.info("Predict cases with default settings...")
predictor = predict_dir(
source_dir=train_data_dir,
target_dir=prediction_dir,
cfg=cfg,
plan=self.plan,
source_models=save_dir,
num_models=1,
num_tta_transforms=None,
case_ids=case_ids,
save_state=True,
model_fn=get_loader_fn(mode=self.trainer_cfg.get("sweep_ckpt", "last")),
**kwargs,
)
logger.info("Start parameter sweep...")
ensembler_cls = self.get_ensembler_cls(key="boxes", dim=self.plan["network_dim"])
sweeper = BoxSweeper(
classes=[item for _, item in cfg["data"]["labels"].items()],
pred_dir=prediction_dir,
gt_dir=processed_eval_labels,
target_metric=self.eval_score_key,
ensembler_cls=ensembler_cls,
save_dir=_save_dir,
)
inference_plan = sweeper.run_postprocessing_sweep()
return inference_plan
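# Sketch (not part of the original file): the configuration keys that
# `from_config_plan` above reads from `model_cfg`, plus the `plan_arch` entries it
# accesses. All values are placeholders; real values come from the planning stage
# and the training configs.
def _demo_retina_unet_config_template():
    model_cfg = {
        "plan_arch_overwrites": {},
        "plan_anchors_overwrites": {},
        "encoder_kwargs": {},
        "decoder_kwargs": {},
        "matcher_kwargs": {},
        "head_classifier_kwargs": {},
        "head_regressor_kwargs": {},
        "head_kwargs": {},
        "head_sampler_kwargs": {},
        "segmenter_kwargs": {},
    }
    plan_arch_keys = [
        "dim", "in_channels", "classifier_classes", "seg_classes",
        "start_channels", "fpn_channels", "head_channels", "decoder_levels",
        "conv_kernels", "strides",
    ]
    return model_cfg, plan_arch_keys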
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from nndet.ptmodule.retinaunet.base import RetinaUNetModule
from nndet.core.boxes.matcher import ATSSMatcher
from nndet.arch.heads.classifier import BCECLassifier
from nndet.arch.heads.regressor import GIoURegressor
from nndet.arch.heads.comb import DetectionHeadHNMNative
from nndet.arch.heads.segmenter import DiCESegmenterFgBg
from nndet.arch.conv import ConvInstanceRelu, ConvGroupRelu
from nndet.ptmodule import MODULE_REGISTRY
@MODULE_REGISTRY.register
class RetinaUNetV001(RetinaUNetModule):
base_conv_cls = ConvInstanceRelu
head_conv_cls = ConvGroupRelu
head_cls = DetectionHeadHNMNative
head_classifier_cls = BCECLassifier
head_regressor_cls = GIoURegressor
matcher_cls = ATSSMatcher
segmenter_cls = DiCESegmenterFgBg
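# Usage sketch (not part of the original file): with the decorator above, the
# class is assumed to be retrievable from MODULE_REGISTRY by its class name;
# model_cfg, trainer_cfg and plan are placeholders produced by the nnDetection
# configuration and planning stages.
#
# from nndet.ptmodule import MODULE_REGISTRY
# module_cls = MODULE_REGISTRY["RetinaUNetV001"]
# module = module_cls(model_cfg=model_cfg, trainer_cfg=trainer_cfg, plan=plan)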