Commit c61a4ce0 authored by mibaumgartner's avatar mibaumgartner
Browse files

WIP

parent cbe90756
#include <torch/extension.h>
#include "cpu/nms.cpp"
#include "cpu/roi_align.cpp"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "NMS C++ and/or CUDA");
......
from abc import ABC, abstractmethod
from typing import TypeVar
class ArchitecturePlanner(ABC):
    """
    Interface for planners that derive architecture and training
    hyperparameters (e.g. batch size and patch size).
    """

    def __init__(self, **kwargs):
        """
        Plan architecture and training hyperparameters (batch size and
        patch size).

        Args:
            **kwargs: arbitrary settings; every key/value pair is stored
                as an attribute on the instance
        """
        for name, value in kwargs.items():
            setattr(self, name, value)

    @abstractmethod
    def plan(self, *args, **kwargs) -> dict:
        """
        Plan architecture and training parameters.

        Args:
            *args: positional arguments determined by Planner
            **kwargs: keyword arguments determined by Planner

        Returns:
            dict: training and architecture information
                `patch_size` (Sequence[int]): patch size
                `batch_size` (int): batch size for training
                `architecture` (dict): dictionary with all parameters
                    needed for the final model
        """
        raise NotImplementedError

    def approximate_vram(self):
        """
        Approximate vram usage of model for planning.

        Default implementation is a no-op; subclasses may override.
        """
        pass

    def get_planner_id(self) -> str:
        """
        Create identifier for this planner.

        Returns:
            str: identifier (the concrete class name)
        """
        return type(self).__name__
# Type variable for functions that accept/return any ArchitecturePlanner subclass
ArchitecturePlannerType = TypeVar('ArchitecturePlannerType', bound=ArchitecturePlanner)
from nndet.planning.architecture.boxes.base import BaseBoxesPlanner
from nndet.planning.architecture.boxes.c002 import BoxC002
This diff is collapsed.
import os
import copy
from typing import Callable, Sequence, List
import torch
import numpy as np
from loguru import logger
from nndet.planning.estimator import MemoryEstimator, MemoryEstimatorDetection
from nndet.planning.architecture.boxes.base import BaseBoxesPlanner
from nndet.planning.architecture.boxes.utils import (
proxy_num_boxes_in_patch,
scale_with_abs_strides,
)
from nndet.core.boxes import (
get_anchor_generator,
expand_to_boxes,
box_center,
box_size_np,
permute_boxes,
)
class BoxC002(BaseBoxesPlanner):
    """
    Detection architecture planner (version C002).

    Plans patch size, batch size, network topology (conv kernels, strides,
    decoder levels) and anchor sizes for a box-based detection network,
    constrained by the VRAM estimate of :attr:`estimator`.
    """

    def __init__(self,
                 preprocessed_output_dir: os.PathLike,
                 save_dir: os.PathLike,
                 network_cls: Callable,
                 estimator: MemoryEstimator = MemoryEstimatorDetection(),
                 model_cfg: dict = None,
                 **kwargs,
                 ):
        """
        Args:
            preprocessed_output_dir: directory with preprocessed data
            save_dir: directory where planner outputs (e.g. plots) are saved
            network_cls: network class; must provide ``from_config_plan``
            estimator: VRAM estimator used during the architecture search.
                NOTE(review): the default instance is created once at import
                time and shared across planner instances — confirm intended.
            model_cfg: config forwarded to ``network_cls.from_config_plan``
            **kwargs: forwarded to the base planner
        """
        super().__init__(
            preprocessed_output_dir=preprocessed_output_dir,
            save_dir=save_dir,
            network_cls=network_cls,
            estimator=estimator,
            model_cfg=model_cfg,
            **kwargs
        )

    def create_default_settings(self):
        """
        Generate default settings for the architecture
        """
        super().create_default_settings()
        # wider start channels in 2d since 2d networks are cheaper per voxel
        self.architecture_kwargs["start_channels"] = 48 if self.dim == 2 else 32
        self.architecture_kwargs["fpn_channels"] = \
            self.architecture_kwargs["start_channels"] * 4
        self.architecture_kwargs["head_channels"] = \
            self.architecture_kwargs["fpn_channels"]
        # training defaults; patch size is refined later by the
        # VRAM-constrained search in _plan_architecture
        self.batch_size = 16 if self.dim == 2 else 4
        self.min_feature_map_size = 8 if self.dim == 2 else 4
        self.num_decoder_level = 5 if self.dim == 2 else 4

    def get_anchor_init(self, boxes: torch.Tensor) -> Sequence[Sequence[int]]:
        """
        Initialize anchors sizes for optimization

        Args:
            boxes: scaled and transposed boxes

        Returns:
            Sequence[Sequence[int]]: anchor initialization
        """
        # boxes carry two coordinates (min/max) per spatial dimension
        box_dim = int(boxes.shape[1]) // 2
        return [(4, 8, 16), ] * box_dim

    def process_properties(self, **kwargs):
        """
        Load dataset properties and extract information
        """
        logger.info("Processing dataset properties")
        # per-case ground-truth boxes and original spacings
        self.all_boxes = [case["boxes"] for case_id, case
                          in self.dataset_properties["instance_props_per_patient"].items()]
        self.all_spacings = [case["original_spacing"] for case_id, case
                             in self.dataset_properties["instance_props_per_patient"].items()]
        self.num_instances_per_case = {case_id: sum(case["num_instances"].values())
                                       for case_id, case in self.dataset_properties["instance_props_per_patient"].items()}
        self.all_ious = self.dataset_properties["all_ious"]
        self.class_ious = self.dataset_properties["class_ious"]
        self.num_instances = self.dataset_properties["num_instances"]
        self.dim = self.dataset_properties["dim"]
        # derive network head sizes from the dataset properties
        self.architecture_kwargs["classifier_classes"] = \
            len(self.dataset_properties["class_dct"])
        self.architecture_kwargs["seg_classes"] = \
            self.architecture_kwargs["classifier_classes"]
        self.architecture_kwargs["in_channels"] = \
            len(self.dataset_properties["modalities"])
        self.architecture_kwargs["dim"] = \
            self.dataset_properties["dim"]

    def plan(self,
             target_spacing_transposed: Sequence[float],
             median_shape_transposed: Sequence[float],
             transpose_forward: Sequence[int],
             mode: str = '3d',
             ) -> dict:
        """
        Plan network architecture, anchors, patch size and batch size

        Args:
            target_spacing_transposed: spacing after data is transposed and resampled
            median_shape_transposed: median shape after data is
                transposed and resampled
            transpose_forward: new ordering of axes for forward pass
            mode: mode to use for planning ('3d' | '2d')

        Returns:
            dict: training and architecture information

        See Also:
            :method:`_plan_architecture`, :method:`_plan_anchors`
        """
        if mode == "2d":
            logger.info("Running 2d mode")
            # 2d mode needs the properties loaded before boxes/spacings
            # can be reduced to two dimensions
            self.process_properties()
            kwargs_2d = self.activate_2d_mode(
                transpose_forward=transpose_forward,
                target_spacing_transposed=target_spacing_transposed,
                median_shape_transposed=median_shape_transposed,
            )
            res = super().plan(**kwargs_2d)
        else:
            res = super().plan(
                transpose_forward=transpose_forward,
                target_spacing_transposed=target_spacing_transposed,
                median_shape_transposed=median_shape_transposed,
            )
        return res

    def activate_2d_mode(self,
                         target_spacing_transposed: Sequence[float],
                         median_shape_transposed: Sequence[float],
                         transpose_forward: Sequence[int],
                         ) -> dict:
        """
        Switch the planner to 2d by dropping the leading (out-of-plane)
        axis from spacings, shapes, transpose order and stored boxes.

        Args:
            target_spacing_transposed: spacing after transpose/resampling
                (leading axis is removed)
            median_shape_transposed: median shape after transpose/resampling
                (leading axis is removed)
            transpose_forward: new ordering of axes for forward pass

        Returns:
            dict: kwargs for :meth:`plan` with the leading axis removed
        """
        target_spacing_transposed = target_spacing_transposed[1:]
        median_shape_transposed = median_shape_transposed[1:]
        keep = copy.copy(transpose_forward[1:])
        # re-index the remaining axes to start at zero
        transpose_forward = [t - 1 for t in keep]
        # map each kept spatial axis to its two box coordinate columns
        # NOTE(review): assumes 3d boxes are stored with the first two axes
        # as (d0_min, d1_min, d0_max, d1_max) and the third axis appended
        # as (z_min, z_max) — confirm box layout against nndet.core.boxes
        keep_box = [0, 0, 0, 0]
        for idx, k in enumerate(keep):
            if k < 2:
                keep_box[idx] = k
                keep_box[idx + 2] = k + 2
            else:
                keep_box[idx] = 2 * k
                keep_box[idx + 2] = 2 * k + 1
        # only 3d boxes (6 coordinates) need their columns reduced
        self.all_boxes = [b[:, keep_box] if (not isinstance(b, list) and b.shape[1] == 6) else b
                          for b in self.all_boxes]
        self.all_spacings = [c[keep] if len(c) == 3 else c for c in self.all_spacings]
        self.dim = 2
        self.architecture_kwargs["dim"] = self.dim
        return {
            "target_spacing_transposed": target_spacing_transposed,
            "median_shape_transposed": median_shape_transposed,
            "transpose_forward": transpose_forward,
        }

    def _plan_architecture(self,
                           target_spacing_transposed: Sequence[float],
                           target_median_shape_transposed: Sequence[float],
                           transpose_forward: Sequence[int],
                           **kwargs,
                           ) -> Sequence[int]:
        """
        Plan patch size and main aspects of the architecture

        Fills entries in :param:`self.architecture_kwargs`:
            `conv_kernels`
            `strides`
            `decoder_levels`

        Args:
            target_spacing_transposed: spacing after data is transposed and resampled
            target_median_shape_transposed: median shape after data is
                transposed and resampled
            transpose_forward: new ordering of axes for forward pass

        Returns:
            Sequence[int]: patch size to use for training
        """
        self.estimator.batch_size = self.batch_size
        patch_size = np.asarray(self._get_initial_patch_size(
            target_spacing_transposed, target_median_shape_transposed))
        # iteratively shrink the patch until the estimated VRAM fits
        first_run = True
        while True:
            if first_run:
                pass
            else:
                # `pooling` and `must_be_divisible_by` stem from the
                # previous loop iteration (guarded by first_run)
                patch_size = self._decrease_patch_size(
                    patch_size, target_median_shape_transposed, pooling, must_be_divisible_by)
            num_pool_per_axis, pooling, convs, patch_size, must_be_divisible_by = \
                self.plan_pool_and_conv_pool_late(patch_size, target_spacing_transposed)
            self.architecture_kwargs["conv_kernels"] = convs
            self.architecture_kwargs["strides"] = pooling
            num_resolutions = len(self.architecture_kwargs["conv_kernels"])
            # NOTE(review): relies on `self.min_decoder_level`, which is not
            # set in this class — presumably provided by the base planner
            decoder_levels_start = min(max(1, num_resolutions - self.num_decoder_level), self.min_decoder_level)
            self.architecture_kwargs["decoder_levels"] = \
                tuple([i for i in range(decoder_levels_start, num_resolutions)])
            # estimate VRAM of a candidate network for this configuration
            _, fits_in_mem = self.estimator.estimate(
                min_shape=must_be_divisible_by,
                target_shape=patch_size,
                in_channels=self.architecture_kwargs["in_channels"],
                network=self.network_cls.from_config_plan(
                    model_cfg=self.model_cfg,
                    plan_arch=self.architecture_kwargs,
                    plan_anchors=self.get_anchors_for_estimation()),
                optimizer_cls=torch.optim.Adam,
                num_instances=self._estimte_num_instances_per_patch(
                    patch_size=patch_size,
                    target_spacing_transposed=target_spacing_transposed,
                    transpose_forward=transpose_forward,
                ),
            )
            if fits_in_mem:
                break
            first_run = False
        logger.info(f"decoder levels: {self.architecture_kwargs['decoder_levels']}; \n"
                    f"pooling strides: {self.architecture_kwargs['strides']}; \n"
                    f"kernel sizes: {self.architecture_kwargs['conv_kernels']}; \n"
                    f"patch size: {patch_size}; \n")
        return patch_size

    def _estimte_num_instances_per_patch(self,
                                         patch_size,
                                         target_spacing_transposed,
                                         transpose_forward,
                                         ) -> int:
        """
        Estimate the maximum number of object instances inside one patch.

        NOTE(review): the name carries a historical typo ("estimte"); kept
        for backwards compatibility with existing callers.

        Args:
            patch_size: candidate patch size
            target_spacing_transposed: spacing after transpose/resampling
            transpose_forward: new ordering of axes for forward pass

        Returns:
            int: maximum proxy instance count over all cases
        """
        max_instances_per_image = []
        for boxes in self._get_scaled_boxes(
                target_spacing_transposed=target_spacing_transposed,
                transpose_forward=transpose_forward,
                cat=False,
        ):
            max_instances_per_image.append(
                max(proxy_num_boxes_in_patch(torch.from_numpy(boxes), patch_size)).item())
        return max(max_instances_per_image)

    def _plan_anchors(self,
                      target_spacing_transposed: Sequence[float],
                      transpose_forward: Sequence[int],
                      **kwargs,
                      ) -> dict:
        """
        Optimize anchors
        """
        boxes_np_full = self._get_scaled_boxes(
            target_spacing_transposed=target_spacing_transposed,
            transpose_forward=transpose_forward,
        )
        boxes_np = self.filter_boxes(boxes_np_full)
        logger.info(f"Filtered {boxes_np_full.shape[0] - boxes_np.shape[0]} "
                    f"boxes, {boxes_np.shape[0]} boxes remaining for anchor "
                    "planning.")
        # center boxes at the origin; only the extent matters for anchors
        boxes_torch = torch.from_numpy(boxes_np).float()
        boxes_torch = boxes_torch - expand_to_boxes(box_center(boxes_torch))
        anchor_generator = get_anchor_generator(self.dim, s_param=True)
        # absolute strides of the used decoder levels, relative to the
        # first (highest-resolution) level
        rel_strides = self.architecture_kwargs["strides"]
        filt_rel_strides = [[1] * self.dim, *rel_strides]
        filt_rel_strides = [filt_rel_strides[i] for i in self.architecture_kwargs["decoder_levels"]]
        strides = np.cumprod(filt_rel_strides, axis=0) / np.asarray(rel_strides[0])
        params = self.find_anchors(boxes_torch, strides.astype(np.int32), anchor_generator)
        scaled_params = {key: scale_with_abs_strides(item, strides, dim_idx) for dim_idx, (key, item) in enumerate(params.items())}
        logger.info(f"Determined Anchors: {params}; Results in params: {scaled_params}")
        self.anchors = scaled_params
        self.anchors["stride"] = 1
        return self.anchors

    def _get_scaled_boxes(self,
                          target_spacing_transposed: Sequence[float],
                          transpose_forward: Sequence[int],
                          cat: bool = True,
                          ) -> np.ndarray:
        """
        training is conducted in preprocessed image space and thus
        we need to scale the extracted boxes to compensate for resampling

        Args:
            target_spacing_transposed: spacing after transpose/resampling
            transpose_forward: new ordering of axes for forward pass
            cat: concatenate all cases into one array; otherwise a list of
                per-case arrays is returned

        Returns:
            np.ndarray: scaled boxes (or list of arrays if ``cat=False``)
        """
        boxes_np_list = []
        for spacing, boxes in zip(self.all_spacings, self.all_boxes):
            # cases without boxes are stored as empty lists/arrays — skip
            if not isinstance(boxes, list) and boxes.size > 0:
                spacing_transposed = np.asarray(spacing)[transpose_forward]
                scaling_transposed = spacing_transposed / np.asarray(target_spacing_transposed)
                boxes_transposed = permute_boxes(np.asarray(boxes), dims=transpose_forward)
                boxes_np_list.append(boxes_transposed * expand_to_boxes(scaling_transposed))
        if cat:
            return np.concatenate(boxes_np_list).astype(np.float32)
        else:
            return boxes_np_list

    @staticmethod
    def _get_initial_patch_size(target_spacing_transposed: np.ndarray,
                                target_median_shape_transposed: Sequence[int],
                                ) -> List[int]:
        """
        Generate initial patch which relies on the spacing of underlying images.
        This is based on the fact that most acquisition protocols are optimized
        to focus on the most important aspects.

        Returns:
            List[int]: initial patch size
        """
        voxels_per_mm = 1 / np.array(target_spacing_transposed)
        # normalize voxels per mm
        input_patch_size = voxels_per_mm / voxels_per_mm.mean()
        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(np.int32)
        # clip it to the median shape of the dataset because patches larger
        # than that make not much sense; account for rectangular patches
        if len(target_spacing_transposed) > 2:
            # 3d: treat the lowest-resolution axis separately
            lowres_axis = np.argmax(target_spacing_transposed)
            isotropic_axes = list(range(len(target_median_shape_transposed)))
            isotropic_axes.pop(lowres_axis)
            min_isotropic_axes_shape = min([target_median_shape_transposed[t] for t in isotropic_axes])
            lowres_shape = target_median_shape_transposed[lowres_axis]
        else:
            lowres_axis = -1
            lowres_shape = None
            min_isotropic_axes_shape = min(target_median_shape_transposed)
        initial_patch_size = []
        for i in range(len(target_median_shape_transposed)):
            if i == lowres_axis:
                assert lowres_shape is not None
                initial_patch_size.append(min(input_patch_size[i], lowres_shape))
            else:
                initial_patch_size.append(min(input_patch_size[i], min_isotropic_axes_shape))
        initial_patch_size = np.round(initial_patch_size).astype(np.int32)
        logger.info(f"Using initial patch size: {initial_patch_size}")
        return initial_patch_size

    def plot_box_distribution(self,
                              target_spacing_transposed: Sequence[float],
                              transpose_forward: Sequence[int],
                              **kwargs):
        """
        Plot histogram with ground truth bounding box distribution for
        all axis
        """
        super().plot_box_distribution()
        try:
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
            import matplotlib.pyplot as plt
        except ImportError:
            # plotting is best-effort; continue without matplotlib
            logger.error("Failed to import matplotlib continue anyway.")
            plt = None
        if plt is not None:
            if isinstance(self.all_boxes, list):
                # drop cases without boxes before concatenating
                _boxes = np.concatenate(
                    [b for b in self.all_boxes if not isinstance(b, list) and b.size > 0], axis=0)
                dists = box_size_np(_boxes)
            else:
                dists = box_size_np(self.all_boxes)
            if dists.shape[1] == 3:
                # 3d: scatter the sizes in original and rescaled space
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                ax.scatter(dists[:, 0], dists[:, 1], dists[:, 2])
                ax.set_title(f"Transpose forward {transpose_forward}")
                plt.savefig(self.save_dir / f'bbox_sizes_3d_orig.png')
                plt.close()
                dists = box_size_np(self._get_scaled_boxes(
                    target_spacing_transposed, transpose_forward))
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                ax.scatter(dists[:, 0], dists[:, 1], dists[:, 2])
                plt.savefig(self.save_dir / f'bbox_sizes_3d.png')
                plt.close()
            else:
                # 2d: scatter the sizes in original and rescaled space
                fig = plt.figure()
                ax = fig.add_subplot(111)
                ax.scatter(dists[:, 0], dists[:, 1])
                ax.grid(True)
                ax.set_title(f"Transpose forward {transpose_forward}")
                plt.savefig(self.save_dir / f'bbox_sizes_2d_orig.png')
                plt.close()
                dists = box_size_np(self._get_scaled_boxes(
                    target_spacing_transposed, transpose_forward))
                fig = plt.figure()
                ax = fig.add_subplot(111)
                ax.scatter(dists[:, 0], dists[:, 1])
                ax.grid(True)
                plt.savefig(self.save_dir / f'bbox_sizes_2d.png')
                plt.close()
from typing import Sequence, List, Union, Tuple
import torch
import numpy as np
from torch import Tensor
from nndet.core.boxes import box_center
def scale_with_abs_strides(seq: Sequence[float],
                           strides: Sequence[Union[Sequence[Union[int, float]], Union[int, float]]],
                           dim_idx: int,
                           ) -> List[Tuple[float]]:
    """
    Scale values with absolute stride between feature maps

    Args:
        seq: sequence to scale
        strides: strides to scale with; entries may be scalars or
            per-dimension sequences
        dim_idx: dimension index used when a stride entry is a sequence

    Returns:
        List[Tuple[float]]: one scaled tuple per stride
    """
    result = []
    for entry in strides:
        factor = entry if isinstance(entry, (float, int)) else entry[dim_idx]
        result.append(tuple(value * factor for value in seq))
    return result
def proxy_num_boxes_in_patch(boxes: Tensor, patch_size: Sequence[int]) -> Tensor:
    """
    Proxy (not exact) count of boxes per patch.

    Args:
        boxes: boxes
        patch_size: patch size

    Returns:
        Tensor: per box, the number of box centers within patch_size / 2
    """
    half_patch = torch.tensor(patch_size, dtype=torch.float).view(1, 1, -1) / 2
    centers = box_center(boxes)  # [N, dims]
    pairwise = torch.abs(centers.unsqueeze(0) - centers.unsqueeze(1))  # [N, N, dims]
    within = (pairwise <= half_patch).prod(dim=-1)  # [N, N]
    return within.sum(dim=1)  # [N]
def comp_num_pool_per_axis(patch_size: Sequence[int],
                           max_num_pool: int,
                           min_feature_map_size: int) -> List[int]:
    """
    Computes the maximum number of pooling operations given a minimal feature map size
    and the patch size

    Args:
        patch_size: input patch size
        max_num_pool: maximum number of pooling operations.
        min_feature_map_size: Minimal size of feature map inside the bottleneck.

    Returns:
        List[int]: max number of pooling operations per axis
    """
    # log2 of the shrink factor per axis, floored to whole pool steps
    raw_counts = np.floor(
        [np.log(axis / min_feature_map_size) / np.log(2) for axis in patch_size]
    ).astype(np.int32)
    # cap every axis at the global pooling limit
    return [min(count, max_num_pool) for count in raw_counts]
def get_shape_must_be_divisible_by(num_pool_per_axis: Sequence[int]) -> np.ndarray:
    """
    Returns a multiple of 2 which indicates by which factor an axis needs to
    be dividable to avoid problems with upsampling

    Args:
        num_pool_per_axis: number of pooling operations per axis

    Returns:
        np.ndarray: necessary divisor of axis
    """
    # each pooling step halves the axis, so the divisor is 2^num_pool
    return np.power(2, np.asarray(num_pool_per_axis))
def pad_shape(shape: Sequence[int], must_be_divisible_by: Sequence[int]) -> np.ndarray:
    """
    Pads shape so that it is divisibly by must_be_divisible_by

    Args:
        shape: shape to pad
        must_be_divisible_by: divisor (scalar or one divisor per axis)

    Returns:
        np.ndarray: padded shape (unchanged axes stay as they are)
    """
    # broadcast a scalar divisor to every axis
    if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
        must_be_divisible_by = [must_be_divisible_by] * len(shape)
    else:
        assert len(must_be_divisible_by) == len(shape)
    # round each axis up to the next multiple; exact multiples are kept
    padded = [s if s % d == 0 else s + d - s % d
              for s, d in zip(shape, must_be_divisible_by)]
    return np.array(padded).astype(np.int32)
def scale_with_abs_strides(seq: Sequence[float],
                           strides: Sequence[Union[Sequence[Union[int, float]], Union[int, float]]],
                           dim_idx: int,
                           ) -> List[Tuple[float]]:
    """
    Scale values with absolute stride between feature maps

    Args:
        seq: sequence to scale
        strides: strides to scale with; entries may be scalars or
            per-dimension sequences
        dim_idx: dimension index used when a stride entry is a sequence

    Returns:
        List[Tuple[float]]: one scaled tuple per stride
    """
    def _factor(stride):
        # sequences provide one stride per dimension; scalars apply to all
        return stride if isinstance(stride, (float, int)) else stride[dim_idx]

    return [tuple(value * _factor(stride) for value in seq) for stride in strides]
def proxy_num_boxes_in_patch(boxes: Tensor, patch_size: Sequence[int]) -> Tensor:
    """
    Proxy (not exact) computation of the box count per patch.

    Args:
        boxes: boxes
        patch_size: patch size

    Returns:
        Tensor: per box, the number of box centers within patch_size / 2
    """
    half_extent = torch.tensor(patch_size, dtype=torch.float)[None, None] / 2  # [1, 1, dims]
    centers = box_center(boxes)  # [N, dims]
    offsets = (centers[:, None] - centers[None]).abs()  # [N, N, dims]
    in_range = (offsets <= half_extent).prod(dim=-1)  # [N, N]
    return in_range.sum(dim=1)  # [N]
def fixed_anchor_init(dim: int):
    """
    Fixed anchors sizes for 2d and 3d

    Args:
        dim: number of dimensions

    Returns:
        dict: fixed params
    """
    if dim == 2:
        return {
            "stride": 1,
            "aspect_ratios": (0.5, 1, 2),
            "sizes": (32, 64, 128, 256),
        }
    # 3d: in-plane sizes plus separate z extents per level
    return {
        "stride": 1,
        "aspect_ratios": (0.5, 1, 2),
        "sizes": ((4, 8, 16), (8, 16, 32), (16, 32, 64), (32, 64, 128)),
        "zsizes": ((2, 3, 4), (4, 6, 8), (8, 12, 16), (12, 24, 48)),
    }
from typing import Mapping, Type
from nndet.utils.registry import Registry
from nndet.planning.experiment.base import PlannerType, AbstractPlanner
# Global registry mapping planner names to planner classes
PLANNER_REGISTRY: Mapping[str, Type[PlannerType]] = Registry()
from nndet.planning.experiment.v001 import D3V001
from __future__ import annotations
import os
from pathlib import Path
from itertools import repeat
from multiprocessing import Pool
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, Optional, List, TypeVar
import numpy as np
from loguru import logger
from nndet.io.load import load_pickle, save_pickle
from nndet.io.paths import (
get_case_ids_from_dir,
get_paths_from_splitted_dir,
)
from nndet.planning.architecture.abstract import ArchitecturePlannerType
from nndet.preprocessing.preprocessor import PreprocessorType
from nndet.planning.experiment.utils import run_create_label_preprocessed
class AbstractPlanner(ABC):
    """
    Base class for experiment planning.

    Loads pre-computed dataset properties and derives a preprocessing and
    training plan (target spacing, normalization, axis transposition);
    subclasses complete the plan with model-specific architecture planning.
    """

    def __init__(self,
                 preprocessed_output_dir: os.PathLike,
                 ):
        """
        Base class for experiment planning

        Args:
            preprocessed_output_dir: path to directory where preprocessed
                data will be saved
        """
        super().__init__()
        self.preprocessed_output_dir = Path(preprocessed_output_dir)
        # accumulated plan; filled by subclasses during plan_experiment
        self.plan: Optional[Dict] = {}
        # axis permutations; set by determine_forward_backward_permutation
        self.transpose_forward = None
        self.transpose_backward = None
        # spacing ratio above which data is treated as anisotropic
        self.anisotropy_threshold = 3
        self.resample_anisotropy_threshold = 3
        # percentile of all spacings used as the resampling target
        self.target_spacing_percentile = 50
        self.data_properties = self.load_data_properties()

    @abstractmethod
    def plan_experiment(self,
                        model_name: str,
                        model_cfg: Dict,
                        ) -> List[str]:
        """
        Plan the whole experiment

        Args:
            model_name: name of model to plan for
            model_cfg: config to initialize model for VRAM estimation

        Returns:
            List: identifiers of created plans
        """
        raise NotImplementedError

    @abstractmethod
    def create_architecture_planner(self,
                                    model_name: str,
                                    model_cfg: dict,
                                    mode: str,
                                    ) -> ArchitecturePlannerType:
        """
        Create Architecture planner

        Args:
            model_name: name of model to plan for
            model_cfg: config to initialize model for VRAM estimation
            mode: current mode of experiment planner
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def create_preprocessor(plan: Dict) -> PreprocessorType:
        """
        Create Preprocessor
        """
        raise NotImplementedError

    @abstractmethod
    def determine_forward_backward_permutation(self):
        """
        Permute dimensions of input. Results should be saved into
        :param:`transpose_forward` and :param:`transpose_backward`

        Raises:
            NotImplementedError: Should be overwritten in subclasses
        """
        raise NotImplementedError

    def load_data_properties(self):
        """
        Load properties from analysis of dataset

        Returns:
            dict: loaded properties
        """
        data_properties_path = self.preprocessed_output_dir / "properties" / "dataset_properties.pkl"
        assert data_properties_path.is_file(), "data properties need to exist. Run data analysis first"
        data_properties = load_pickle(data_properties_path)
        return data_properties

    def get_data_identifier(self, mode: str):
        """
        By default each plan is associated with its own folder
        If only the architecture changed, this can be overwritten
        to use the data from a different plan (useful for dev)
        """
        return f"{self.__class__.__name__}_{mode}"

    def plan_base(self) -> Dict:
        """
        Create the base plan

        Returns:
            Dict: plan with base attributes
                `target_spacing`: target to resample data
                `normalization_schemes` normalization type for each modality
                `use_mask_for_norm`: use mask for norm
                `anisotropy_threshold`: threshold used to trigger anisotropy
                    settings
                `resample_anisotropy_threshold`: threshold to trigger different
                    resampling schemes
                `target_spacing_percentile`: target spacing percentile
                    used to create target spacing
                `dim`: dimensionality of data (2 or 3)
                `transpose_forward`: transpose forward order
                `transpose_backward`: transpose back order
        """
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        logger.info(f"Are we using the nonzero maks for normalization? {use_nonzero_mask_for_normalization}")
        target_spacing = self.determine_target_spacing()
        logger.info(f"Base target spacing is {target_spacing}")
        # populates self.transpose_forward / self.transpose_backward
        self.determine_forward_backward_permutation()
        normalization_schemes = self.determine_normalization()
        logger.info(f"Normalization schemes {normalization_schemes}")
        plan = {
            'target_spacing': target_spacing,
            'normalization_schemes': normalization_schemes,
            'use_mask_for_norm': use_nonzero_mask_for_normalization,
            'anisotropy_threshold': self.anisotropy_threshold,
            'resample_anisotropy_threshold': self.resample_anisotropy_threshold,
            'target_spacing_percentile': self.target_spacing_percentile,
            'dim': self.data_properties['dim'],
            "num_modalities": len(list(self.data_properties['modalities'].keys())),
            "all_classes": self.data_properties['all_classes'],
            "num_classes": len(self.data_properties['all_classes']),
            'transpose_forward': self.transpose_forward,
            'transpose_backward': self.transpose_backward,
            'dataset_properties': self.data_properties,
            "planner_id": self.__class__.__name__,
        }
        return plan

    def plan_base_stage(self,
                        base_plan: Dict,
                        model_name: str,
                        model_cfg: dict,
                        ):
        """
        Plan the first stage of training

        Args:
            base_plan: basic plan
            model_name: name of model to plan for
            model_cfg: config to initialize model for VRAM estimation

        Returns:
            dict: properties of stage
                `patch_size`
                `batch_size`
                `architecture` (dict): kwargs for architecture
                `current_spacing`
                `original_spacing`
                `median_shape_transposed`
                `do_dummy_2D_data_aug`
        """
        target_spacing = base_plan['target_spacing']
        spacings = self.data_properties['all_spacings']
        sizes = self.data_properties['all_sizes']
        # shapes of all cases after resampling to the target spacing
        new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
        median_shape = np.median(np.vstack(new_shapes), 0)
        logger.info(f"The median shape of the dataset is {median_shape}")
        max_shape = np.max(np.vstack(new_shapes), 0)
        logger.info(f"The max shape in the dataset is {max_shape}")
        min_shape = np.min(np.vstack(new_shapes), 0)
        logger.info(f"The min shape in the dataset is {min_shape}")
        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        logger.info(f"The transposed median shape of the dataset is {median_shape_transposed}")
        architecture_planner = self.create_architecture_planner(
            model_name=model_name,
            model_cfg=model_cfg,
            mode=base_plan["mode"],
        )
        architecture_plan = architecture_planner.plan(
            target_spacing_transposed=target_spacing_transposed,
            median_shape_transposed=median_shape_transposed,
            transpose_forward=self.transpose_forward,
            mode=base_plan["mode"],
        )
        patch_size = architecture_plan["patch_size"]
        # strongly rectangular patches trigger dummy 2d augmentation
        do_dummy_2d_data_aug = (max(patch_size) / min(patch_size)) > self.anisotropy_threshold
        base_plan.update(architecture_plan)
        base_plan["target_spacing_transposed"] = target_spacing_transposed
        base_plan["median_shape_transposed"] = median_shape_transposed
        base_plan["do_dummy_2D_data_aug"] = do_dummy_2d_data_aug
        return base_plan

    def determine_target_spacing(self) -> np.ndarray:
        """
        Determine target spacing.
        Same as nnUNet v21
        https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py
        """
        spacings = self.data_properties['all_spacings']
        sizes = self.data_properties['all_sizes']
        target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
        target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
        target_size_mm = np.array(target) * np.array(target_size)
        # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
        # the following properties:
        # - one axis which much lower resolution than the others
        # - the lowres axis has much less voxels than the others
        # - (the size in mm of the lowres axis is also reduced)
        worst_spacing_axis = np.argmax(target)
        other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
        other_spacings = [target[i] for i in other_axes]
        other_sizes = [target_size[i] for i in other_axes]
        has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * min(other_spacings))
        has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
        # we don't use the last one for now
        # median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)
        if has_aniso_spacing and has_aniso_voxels:
            spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
            target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
            # don't let the spacing of that axis get higher than the other axes
            if target_spacing_of_that_axis < min(other_spacings):
                target_spacing_of_that_axis = max(min(other_spacings), target_spacing_of_that_axis) + 1e-5
            target[worst_spacing_axis] = target_spacing_of_that_axis
        return target

    def determine_postprocessing(self) -> dict:
        """
        Placeholder for the future

        Deprecated version returned:
            'keep_only_largest_region'
            'min_region_size_per_class'
            'min_size_per_class'
        """
        logger.warning("No planning for post-processing implemented.")
        return {}

    def determine_normalization(self) -> Dict[int, str]:
        """
        Determine normalization scheme for data

        Returns:
            Dict[int, str]: integer index represents modality and string is
                either `CT` or `nonCT`
        """
        schemes = OrderedDict()
        modalities = self.data_properties['modalities']
        num_modalities = len(list(modalities.keys()))
        for i in range(num_modalities):
            if modalities[i] == "CT":
                schemes[i] = "CT"
            elif modalities[i] == "CT2":
                schemes[i] = "CT2"
            else:
                schemes[i] = "nonCT"
        return schemes

    def determine_whether_to_use_mask_for_norm(self) -> Dict[int, bool]:
        """
        Determine if only foreground values should be used for normalization for all modalities

        Returns:
            Dict[int, bool]: result for each modality
        """
        # only use the nonzero mask for normalization if the cropping based on it resulted in a decrease in
        # image size (this is an indication that the data is something like brats/isles and then we want to
        # normalize in the brain region only)
        modalities = self.data_properties['modalities']
        num_modalities = len(list(modalities.keys()))
        use_mask_for_norm = OrderedDict()
        for i in range(num_modalities):
            if "CT" in modalities[i]:
                use_mask_for_norm[i] = False
            else:
                all_size_reductions = list(self.data_properties["size_reductions"].values())
                if np.median(all_size_reductions) < 3 / 4.:
                    logger.info("using nonzero mask for normalization")
                    use_mask_for_norm[i] = True
                else:
                    logger.info("not using nonzero mask for normalization")
                    use_mask_for_norm[i] = False
        return use_mask_for_norm

    def save_plan(self, mode: str) -> str:
        """
        Save plan

        Args:
            mode: plan mode

        Return:
            str: plan identifier
        """
        self.preprocessed_output_dir.mkdir(
            parents=True,
            exist_ok=True,
        )
        identifier = f"{self.__class__.__name__}_{mode}"
        save_pickle(self.plan, self.preprocessed_output_dir / f"{identifier}.pkl")
        return identifier

    def run_preprocessing(
            self,
            cropped_data_dir: os.PathLike,
            plan: dict,
            num_processes: int,
    ):
        """
        Runs data preprocessing

        Args:
            cropped_data_dir: base cropped dir
            plan: plan to use for preprocessing
            num_processes: number of processes to use for preprocessing
        """
        preprocessor = self.create_preprocessor(plan=plan)
        preprocessor.run(
            target_spacings=[plan["target_spacing"]],
            identifiers=[plan["data_identifier"]],
            cropped_data_dir=Path(cropped_data_dir),
            preprocessed_output_dir=self.preprocessed_output_dir,
            num_processes=num_processes,
        )
        # also derive evaluation labels from the preprocessed images
        self.create_labels_tr_preprocessed(
            preprocessed_plan_dir=self.preprocessed_output_dir / plan["data_identifier"],
            dim=3,
            num_processes=num_processes,
        )

    @staticmethod
    def create_labels_tr_preprocessed(
            preprocessed_plan_dir: Path,
            dim: int,
            num_processes: int = 6,
    ):
        """
        Creates labels for visualization and analysis purposes from
        preprocessed data

        Args:
            preprocessed_plan_dir: path to preprocessed plan dir
            dim: number of spatial dimensions
            num_processes: number of processes to use
        """
        source_dir = preprocessed_plan_dir / "imagesTr"
        target_dir = preprocessed_plan_dir / "labelsTr"
        target_dir.mkdir(parents=True, exist_ok=True)
        case_ids = get_case_ids_from_dir(source_dir,
                                         remove_modality=False,
                                         pattern="*.npz",
                                         )
        logger.info('Preparing preprocessed evaluation labels')
        # num_processes <= 0 disables multiprocessing (useful for debugging)
        if num_processes > 0:
            with Pool(processes=num_processes) as p:
                p.starmap(run_create_label_preprocessed,
                          zip(repeat(source_dir),
                              case_ids,
                              repeat(dim),
                              repeat(target_dir),
                              )
                          )
        else:
            for cid in case_ids:
                run_create_label_preprocessed(source_dir, cid, dim, target_dir)

    @classmethod
    def run_preprocessing_test(cls,
                               preprocessed_output_dir: os.PathLike,
                               splitted_4d_output_dir: os.PathLike,
                               plan: dict,
                               num_processes: int = 0,
                               ):
        """
        Run preprocessing of test data

        Args:
            preprocessed_output_dir: base dir of preprocessed data
            splitted_4d_output_dir: base dir of splitted data
            plan: plan to use for preprocessing
            num_processes: number of processes to use for preprocessing
        """
        logger.info("Running preprocessing of test cases")
        splitted_4d_output_dir = Path(splitted_4d_output_dir)
        target_dir = Path(preprocessed_output_dir) / plan["data_identifier"] / "imagesTs"
        target_dir.mkdir(parents=True, exist_ok=True)
        # skip cases which were already preprocessed in a previous run
        cases_processed = get_case_ids_from_dir(
            target_dir,
            remove_modality=False,
            pattern="*.npz",
        )
        cases = get_paths_from_splitted_dir(
            num_modalities=plan["num_modalities"],
            splitted_4d_output_dir=splitted_4d_output_dir,
            test=True,
            labels=False,
            remove_ids=cases_processed,
        )
        logger.info(f"Found {len(cases)} cases for preprocssing in {splitted_4d_output_dir} "
                    f"and {len(cases_processed)} alrady processed cases.")
        preprocessor = cls.create_preprocessor(plan=plan)
        # num_processes <= 0 disables multiprocessing (useful for debugging)
        if num_processes > 0:
            with Pool(processes=num_processes) as p:
                p.starmap(preprocessor.run_test,
                          zip(cases,
                              repeat(plan["target_spacing"]),
                              repeat(target_dir),
                              )
                          )
        else:
            for c in cases:
                preprocessor.run_test(c, plan["target_spacing"], target_dir)
# Type variable for functions that accept/return any AbstractPlanner subclass
PlannerType = TypeVar('PlannerType', bound=AbstractPlanner)
import os
import numpy as np
from pathlib import Path
from loguru import logger
from itertools import repeat
from multiprocessing import Pool
from typing import Dict
from nndet.utils.itk import load_sitk_as_array
from nndet.io.load import load_json, load_pickle
from nndet.io.paths import get_case_ids_from_dir
from nndet.io.transforms.instances import (
get_bbox_np,
instances_to_segmentation_np,
)
def create_label_case(
    target_dir: Path,
    case_id: str,
    instances: np.ndarray,
    mapping: Dict[int, int],
    dim: int,
) -> None:
    """
    Create labels for evaluation and analysis purposes

    Saves three compressed archives per case: the instance segmentation
    (with its id->class mapping), derived bounding boxes, and the semantic
    segmentation. Existing cases are skipped.

    Args:
        target_dir: target dir to save labels
        case_id: case identifier
        instances: instance segmentation
        mapping: map each instance id to a class (classes start from 0)
        dim: spatial dimensions
    """
    save_paths = {
        "instances": target_dir / f"{case_id}_instances_gt.npz",
        "boxes": target_dir / f"{case_id}_boxes_gt.npz",
        "seg": target_dir / f"{case_id}_seg_gt.npz",
    }

    if all(p.is_file() for p in save_paths.values()):
        logger.warning(f"Skipping prepare label {case_id} because it already exists")
        return

    logger.info(f"Preparing label {case_id}")
    # add a leading channel axis if it is missing
    if instances.ndim == dim:
        instances = instances[None]

    np.savez_compressed(str(save_paths["instances"]),
                        instances=instances, mapping=mapping,
                        )
    box_data = get_bbox_np(instances, mapping, dim=dim)
    np.savez_compressed(str(save_paths["boxes"]), **box_data)
    np.savez_compressed(str(save_paths["seg"]),
                        seg=instances_to_segmentation_np(instances, mapping),
                        )
def create_labels(
        preprocessed_output_dir: os.PathLike,
        source_dir: os.PathLike,
        num_processes: int = 6,
        dim: int = 3,
        ):
    """
    Creates labels for visualization and analysis purposes from raw labels
    Prepares: instance segmentation, bounding boxes, semantic segmentation

    Args:
        preprocessed_output_dir: base dir to write the generated labels into
            (``labelsTr``/``labelsTs`` subdirectories are created here)
        source_dir: base dir which contains labelsTr/labelsTs
        num_processes: number of processes to use; a non-positive value
            processes the cases sequentially
        dim: number of spatial dimensions (was previously hard coded to 3;
            default keeps the old behavior)
    """
    source_dir = Path(source_dir)
    for postfix in ["Tr", "Ts"]:
        # only prepare splits that actually exist on disk
        if (source_label_dir := source_dir / f"labels{postfix}").is_dir():
            logger.info(f'Preparing {postfix} evaluation labels')
            target_dir = Path(preprocessed_output_dir) / f"labels{postfix}"
            target_dir.mkdir(parents=True, exist_ok=True)

            case_ids = get_case_ids_from_dir(source_label_dir,
                                             remove_modality=False,
                                             pattern="*.json",
                                             )
            if num_processes > 0:
                with Pool(processes=num_processes) as p:
                    p.starmap(run_create_label,
                              zip(repeat(source_label_dir),
                                  case_ids,
                                  repeat(dim),
                                  repeat(target_dir),
                                  )
                              )
            else:
                for cid in case_ids:
                    run_create_label(source_label_dir, cid, dim, target_dir)
def run_create_label(source_label_dir: Path,
                     case_id: str,
                     dim: int,
                     target_dir: Path,
                     ):
    """
    Helper to run preparation with multiprocessing

    Loads the raw instance segmentation (NIfTI) and its json properties for
    one case and forwards them to :func:`create_label_case`.

    Args:
        source_label_dir: directory with labels
        case_id: case id to process
        dim: number of spatial dimensions
        target_dir: directory to save results
    """
    raw_instances = load_sitk_as_array(source_label_dir / f"{case_id}.nii.gz")[0]
    properties = load_json(source_label_dir / f"{case_id}.json")

    # make sure a leading channel axis exists before further processing
    if raw_instances.ndim == dim:
        raw_instances = raw_instances[None]
    raw_instances = raw_instances.astype(np.int32)

    # instance id -> class id (json keys come back as strings)
    id_to_class = {int(key): int(value)
                   for key, value in properties["instances"].items()}
    create_label_case(
        target_dir=target_dir,
        case_id=case_id,
        instances=raw_instances,
        mapping=id_to_class,
        dim=dim,
    )
def run_create_label_preprocessed(
        source_dir: Path,
        case_id: str,
        dim: int,
        target_dir: Path,
        ):
    """
    Helper to run preparation with multiprocessing

    Loads the preprocessed segmentation (npz) and pickled properties for one
    case and forwards them to :func:`create_label_case`.

    Args:
        source_dir: directory with labels
        case_id: case id to process
        dim: number of spatial dimensions
        target_dir: directory to save results
    """
    seg = np.load(str(source_dir / f"{case_id}.npz"), mmap_mode="r")["seg"]
    properties = load_pickle(source_dir / f"{case_id}.pkl")

    # instance id -> class id (keys may be stored as strings)
    id_to_class = {int(key): int(value)
                   for key, value in properties["instances"].items()}
    create_label_case(
        target_dir=target_dir,
        case_id=case_id,
        instances=seg,
        mapping=id_to_class,
        dim=dim,
    )
from typing import Dict, Optional, List
import numpy as np
from nndet.ptmodule import MODULE_REGISTRY
from nndet.planning.experiment import PLANNER_REGISTRY, AbstractPlanner
from nndet.planning.estimator import MemoryEstimatorDetection
from nndet.planning.architecture.boxes import BoxC002
from nndet.preprocessing.preprocessor import GenericPreprocessor
@PLANNER_REGISTRY.register
class D3V001(AbstractPlanner):
    """
    3D detection planner: plans a single full-resolution 3d stage using a
    :class:`BoxC002` architecture planner with detection VRAM estimation and
    a :class:`GenericPreprocessor` for data preparation.
    """

    def plan_experiment(self,
                        model_name: str,
                        model_cfg: Dict,
                        ) -> List[str]:
        """
        Plan the whole experiment (currently only one stage is supported)
        (uses :func:`self.save_plans()` to save the results)

        Args:
            model_name: name of model to plan for
            model_cfg: config to initialize model for VRAM estimation

        Returns:
            List: identifiers of created plans
        """
        identifiers = []

        # shared settings for the (single) 3d stage
        base_plan = self.plan_base()
        base_plan["postprocessing"] = self.determine_postprocessing()
        base_plan["mode"] = "3d"
        base_plan["data_identifier"] = self.get_data_identifier(mode=base_plan["mode"])
        base_plan["network_dim"] = 3
        base_plan["dataloader_kwargs"] = {}

        self.plan = self.plan_base_stage(base_plan,
                                         model_name=model_name,
                                         model_cfg=model_cfg,
                                         )
        identifiers.append(self.save_plan(mode=base_plan["mode"]))
        return identifiers

    def create_architecture_planner(self,
                                    model_name: str,
                                    model_cfg: dict,
                                    mode: str,
                                    ) -> BoxC002:
        """
        Create Architecture planner

        Args:
            model_name: registry key of the network module to plan for
            model_cfg: config to initialize the model for VRAM estimation
            mode: planning mode (used to name the analysis output directory)

        Returns:
            BoxC002: configured architecture planner
        """
        estimator = MemoryEstimatorDetection()
        architecture_planner = BoxC002(
            preprocessed_output_dir=self.preprocessed_output_dir,
            save_dir=self.preprocessed_output_dir / "analysis" / f"{self.__class__.__name__}_{mode}",
            estimator=estimator,
            network_cls=MODULE_REGISTRY.get(model_name),
            model_cfg=model_cfg,
        )
        return architecture_planner

    @staticmethod
    def create_preprocessor(plan: Dict) -> GenericPreprocessor:
        """
        Create Preprocessor

        Args:
            plan: plan providing normalization, transpose and resampling settings

        Returns:
            GenericPreprocessor: configured preprocessor
        """
        preprocessor = GenericPreprocessor(
            norm_scheme_per_modality=plan['normalization_schemes'],
            use_mask_for_norm=plan['use_mask_for_norm'],
            transpose_forward=plan['transpose_forward'],
            intensity_properties=plan['dataset_properties']['intensity_properties'],
            resample_anisotropy_threshold=plan['resample_anisotropy_threshold'],
        )
        return preprocessor

    def determine_forward_backward_permutation(self):
        """
        Determine position of z direction (absolute position is defined by z_first)
        Result is saved into `transpose_forward` and `transpose_backward`
        """
        # NOTE: the previous version also computed resampled sizes from
        # all_spacings/all_sizes but never used them; that dead code was removed.
        target_spacing = self.determine_target_spacing()
        dims = len(target_spacing)

        # move the axis with the largest spacing (lowest resolution) to the front
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(dims)) if i != max_spacing_axis]
        # alternative convention: remaining_axes + [max_spacing_axis]  # y, x, z
        self.transpose_forward = [max_spacing_axis] + remaining_axes  # z, y, x
        # inverse permutation of transpose_forward
        self.transpose_backward = [np.argwhere(np.array(
            self.transpose_forward) == i)[0][0] for i in range(dims)]
from nndet.preprocessing.crop import *
from nndet.preprocessing.preprocessor import *
......@@ -21,7 +21,7 @@ from loguru import logger
from abc import ABC, abstractmethod
from multiprocessing import Pool
from pathlib import Path
from typing import Dict, Sequence, List, Tuple, Union
from typing import Dict, Sequence, List, Tuple, TypeVar, Union
from itertools import repeat
from nndet.io.transforms.instances import instances_to_boxes_np
......@@ -639,3 +639,6 @@ class GenericPreprocessor:
seg=seg,
)
return data.astype(np.float32), seg.astype(np.int32), properties
# Generic alias for any concrete preprocessor implementation
# (bound to AbstractPreprocessor), for annotating factory/helper signatures.
PreprocessorType = TypeVar('PreprocessorType', bound=AbstractPreprocessor)
......@@ -3,4 +3,5 @@ from nndet.utils.registry import Registry
from nndet.ptmodule.base_module import LightningBaseModule
MODULE_REGISTRY: Mapping[str, Type[LightningBaseModule]] = Registry()
# register modules
from nndet.ptmodule.retinaunet import *
from nndet.ptmodule.retinaunet.base import RetinaUNetModule
from nndet.ptmodule.retinaunet.v001 import RetinaUNetV001
from nndet.ptmodule.retinaunet.c010 import RetinaUNetC010
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment