Commit 2e2bd60b authored by mibaumgartner's avatar mibaumgartner
Browse files

add refactored low res trigger

parent 45543128
......@@ -438,7 +438,7 @@ In many cases this limitation can be circumvented by converting the bounding box
<summary>Mask RCNN and 2D Data sets</summary>
<br>
2D data sets and Mask R-CNN are not supported in the first release.
We hope to provide these sometime in the future.
We hope to provide these in the future.
</details>
<details close>
......
......@@ -4,11 +4,9 @@ defaults:
module: RetinaUNetV001
predictor: BoxPredictorSelective
plan: D3V001_3d
planners:
2d: [D2C002]
3d: [D3V001] # [D3C002LR15, D3C002LR20] [D3C002NR, D3C002RibFrac] [D2C002, D3C002]
plan: D3V001_3d # plan used for training
planner: D3V001 # planner used for preprocessing
augment_cfg:
augmentation: ${augmentation}
......
......@@ -4,11 +4,9 @@ defaults:
module: RetinaUNetV001
predictor: BoxPredictorSelective
plan: D3V001_3d
planners:
2d: [D2C002]
3d: [D3V001] # [D3C002LR15, D3C002LR20] [D3C002NR, D3C002RibFrac] [D2C002, D3C002]
plan: D3V001_3d # plan used for training
planner: D3V001 # planner used for preprocessing
augment_cfg:
augmentation: ${augmentation}
......
......@@ -171,3 +171,7 @@ def fixed_anchor_init(dim: int):
anchor_plan["sizes"] = ((4, 8, 16), (8, 16, 32), (16, 32, 64), (32, 64, 128))
anchor_plan["zsizes"] = ((2, 3, 4), (4, 6, 8), (8, 12, 16), (12, 24, 48))
return anchor_plan
def concatenate_property_boxes(all_boxes: Sequence[np.ndarray]) -> np.ndarray:
    """
    Stack per-case box arrays into a single array.

    Entries that are plain lists (cases without box properties) or empty
    arrays are skipped before concatenation.

    Args:
        all_boxes: one box array per case; may contain empty entries

    Returns:
        np.ndarray: all non-empty box arrays concatenated along axis 0
    """
    valid = []
    for boxes in all_boxes:
        if isinstance(boxes, list):
            continue
        if boxes.size == 0:
            continue
        valid.append(boxes)
    return np.concatenate(valid, axis=0)
......@@ -35,8 +35,6 @@ class AbstractPlanner(ABC):
super().__init__()
self.preprocessed_output_dir = Path(preprocessed_output_dir)
self.plan: Optional[Dict] = {}
self.transpose_forward = None
self.transpose_backward = None
......@@ -88,16 +86,34 @@ class AbstractPlanner(ABC):
raise NotImplementedError
@abstractmethod
def determine_forward_backward_permutation(self):
def determine_forward_backward_permutation(self, mode: str):
    """
    Permute dimensions of input. Results should be saved into
    :param:`transpose_forward` and :param:`transpose_backward`

    Args:
        mode: define current operation mode. Typically one of
            '2d' | '3d' | '3dlr1'

    Raises:
        NotImplementedError: Should be overwritten in subclasses
    """
    raise NotImplementedError
@abstractmethod
def determine_target_spacing(self, mode: str) -> np.ndarray:
    """
    Determine target spacing.

    Args:
        mode: define current operation mode. Typically one of
            '2d' | '3d' | '3dlr1'

    Returns:
        np.ndarray: spacing to resample the data to

    Same as nnUNet v21
    https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py

    Raises:
        NotImplementedError: should be overwritten in subclasses
    """
    raise NotImplementedError
def load_data_properties(self):
"""
Load properties from analysis of dataset
......@@ -119,12 +135,17 @@ class AbstractPlanner(ABC):
"""
return f"{self.__class__.__name__}_{mode}"
def plan_base(self) -> Dict:
def plan_base(self, mode: str) -> Dict:
"""
Create the base plan
Args:
mode: define current operation mode. Typically one of
'2d' | '3d' | '3dlr1'
Returns:
Dict: plan with base attributes
'mode': selected mode for plan
`target_spacing`: target to resample data
`normalization_schemes` normalization type for each modality
`use_mask_for_norm`: use mask for norm
......@@ -141,13 +162,14 @@ class AbstractPlanner(ABC):
"""
use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
logger.info(f"Are we using the nonzero maks for normalization? {use_nonzero_mask_for_normalization}")
target_spacing = self.determine_target_spacing()
target_spacing = self.determine_target_spacing(mode=mode)
logger.info(f"Base target spacing is {target_spacing}")
self.determine_forward_backward_permutation()
self.determine_forward_backward_permutation(mode=mode)
normalization_schemes = self.determine_normalization()
logger.info(f"Normalization schemes {normalization_schemes}")
plan = {
'mode': mode,
'target_spacing': target_spacing,
'normalization_schemes': normalization_schemes,
'use_mask_for_norm': use_nonzero_mask_for_normalization,
......@@ -225,48 +247,14 @@ class AbstractPlanner(ABC):
base_plan["do_dummy_2D_data_aug"] = do_dummy_2d_data_aug
return base_plan
def determine_target_spacing(self) -> np.ndarray:
    """
    Determine target spacing.

    Same as nnUNet v21
    https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py

    Returns:
        np.ndarray: target spacing (one value per spatial axis)
    """
    spacings = self.data_properties['all_spacings']
    sizes = self.data_properties['all_sizes']

    # dataset-wide percentile (configured on the planner) of spacings and sizes
    target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
    target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
    # kept for the (currently disabled) size-in-mm criterion below
    target_size_mm = np.array(target) * np.array(target_size)
    # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
    # the following properties:
    # - one axis with a much lower resolution than the others
    # - the lowres axis has much less voxels than the others
    # - (the size in mm of the lowres axis is also reduced)
    worst_spacing_axis = np.argmax(target)
    other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
    other_spacings = [target[i] for i in other_axes]
    other_sizes = [target_size[i] for i in other_axes]

    has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * min(other_spacings))
    has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
    # we don't use the last one for now
    # median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)

    if has_aniso_spacing and has_aniso_voxels:
        spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
        # 10th percentile: resample the anisotropic axis less aggressively
        target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
        # don't let the spacing of that axis get higher than the other axes
        if target_spacing_of_that_axis < min(other_spacings):
            target_spacing_of_that_axis = max(min(other_spacings), target_spacing_of_that_axis) + 1e-5
        target[worst_spacing_axis] = target_spacing_of_that_axis
    return target
def determine_postprocessing(self) -> dict:
def determine_postprocessing(self, mode: str) -> dict:
"""
Placeholder for the future
Args:
mode: define current operation mode. Typically one of
'2d' | '3d' | '3dlr1'
Deprecated version returned:
'keep_only_largest_region'
'min_region_size_per_class'
......@@ -324,7 +312,7 @@ class AbstractPlanner(ABC):
use_mask_for_norm[i] = False
return use_mask_for_norm
def save_plan(self, mode: str) -> str:
def save_plan(self, plan: dict, mode: str) -> str:
"""
Save plan
......@@ -339,7 +327,7 @@ class AbstractPlanner(ABC):
exist_ok=True,
)
identifier = f"{self.__class__.__name__}_{mode}"
save_pickle(self.plan, self.preprocessed_output_dir / f"{identifier}.pkl")
save_pickle(plan, self.preprocessed_output_dir / f"{identifier}.pkl")
return identifier
def run_preprocessing(
......
import os
from nndet.core.boxes.ops_np import box_size_np
import numpy as np
from pathlib import Path
......
from typing import Dict, Optional, List
from typing import Dict, List, Sequence
import numpy as np
from loguru import logger
from nndet.ptmodule import MODULE_REGISTRY
from nndet.planning.experiment import PLANNER_REGISTRY, AbstractPlanner
from nndet.planning.estimator import MemoryEstimatorDetection
from nndet.planning.architecture.boxes import BoxC002
from nndet.preprocessing.preprocessor import GenericPreprocessor
from nndet.core.boxes.ops_np import box_size_np
from nndet.planning.architecture.boxes.utils import concatenate_property_boxes
@PLANNER_REGISTRY.register
......@@ -27,19 +31,43 @@ class D3V001(AbstractPlanner):
List: identifiers of created plans
"""
identifiers = []
base_plan = self.plan_base()
base_plan["postprocessing"] = self.determine_postprocessing()
base_plan["mode"] = "3d"
base_plan["data_identifier"] = self.get_data_identifier(mode=base_plan["mode"])
base_plan["network_dim"] = 3
base_plan["dataloader_kwargs"] = {}
self.plan = self.plan_base_stage(base_plan,
model_name=model_name,
model_cfg=model_cfg,
)
identifiers.append(self.save_plan(mode=base_plan["mode"]))
# create full resolution 3d plan
mode = "3d"
plan_3d = self.plan_base(mode=mode)
plan_3d["network_dim"] = 3
plan_3d["dataloader_kwargs"] = {}
plan_3d["data_identifier"] = self.get_data_identifier(mode=mode)
plan_3d["postprocessing"] = self.determine_postprocessing(mode=mode)
plan_3d = self.plan_base_stage(
plan_3d,
model_name=model_name,
model_cfg=model_cfg,
)
# determine if additional low res model needs to be trained
plan_3d["trigger_lr1"] = self.trigger_low_res_model(
prev_res_patch_size=plan_3d["patch_size"],
transpose_forward=plan_3d["transpose_forward"],
)
identifiers.append(self.save_plan(plan=plan_3d, mode=plan_3d["mode"]))
if plan_3d["trigger_lr1"]:
logger.info("Triggered Low Resolution Model")
mode = "3dlr1"
plan_3dlr1 = self.plan_base(mode=mode)
plan_3dlr1["network_dim"] = 3
plan_3dlr1["dataloader_kwargs"] = {}
plan_3dlr1["data_identifier"] = self.get_data_identifier(mode=mode)
plan_3dlr1["postprocessing"] = self.determine_postprocessing(mode=mode)
self.plan = self.plan_base_stage(
plan_3dlr1,
model_name=model_name,
model_cfg=model_cfg,
)
identifiers.append(self.save_plan(plan=plan_3dlr1, mode=plan_3dlr1["mode"]))
return identifiers
def create_architecture_planner(self,
......@@ -74,7 +102,7 @@ class D3V001(AbstractPlanner):
)
return preprocessor
def determine_forward_backward_permutation(self):
def determine_forward_backward_permutation(self, mode: str):
"""
Determine position of z direction (absolute position is defined by z_first)
Result is
......@@ -83,7 +111,7 @@ class D3V001(AbstractPlanner):
spacings = self.data_properties['all_spacings']
sizes = self.data_properties['all_sizes']
target_spacing = self.determine_target_spacing()
target_spacing = self.determine_target_spacing(mode=mode)
new_sizes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
dims = len(target_spacing)
......@@ -93,3 +121,90 @@ class D3V001(AbstractPlanner):
self.transpose_forward = [max_spacing_axis] + remaining_axes # z, y, x
self.transpose_backward = [np.argwhere(np.array(
self.transpose_forward) == i)[0][0] for i in range(dims)]
def determine_target_spacing(self, mode: str) -> np.ndarray:
    """
    Determine target spacing

    For low resolution modes ('3dlrX') the base spacing is scaled by
    2 ** X, i.e. every level halves the resolution along each axis.

    Args:
        mode: Current planning mode. Typically one of '2d' | '3d' | '3dlr1'

    Raises:
        RuntimeError: not supported mode (supported are 2d, 3d, 3dlrX)

    Returns:
        np.ndarray: target spacing
    """
    base_target_spacing = self._target_spacing_base()
    # guard clauses: full-resolution modes use the base spacing directly
    if mode in ("2d", "3d"):
        return base_target_spacing
    if "lr" not in mode:
        raise RuntimeError(f"Mode {mode} is not supported for target spacing.")
    # e.g. '3dlr1' -> downscale level 1 -> spacing * 2
    downscale = int(mode.split('lr')[-1])
    return base_target_spacing * (2 ** downscale)
def _target_spacing_base(self) -> np.ndarray:
    """
    Determine target spacing.

    Same as nnUNet v21
    https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py

    Returns:
        np.ndarray: base target spacing (one value per spatial axis)
    """
    spacings = self.data_properties['all_spacings']
    sizes = self.data_properties['all_sizes']

    # dataset-wide percentile (configured on the planner) of spacings and sizes
    target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
    target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
    # kept for the (currently disabled) size-in-mm criterion below
    target_size_mm = np.array(target) * np.array(target_size)
    # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
    # the following properties:
    # - one axis with a much lower resolution than the others
    # - the lowres axis has much less voxels than the others
    # - (the size in mm of the lowres axis is also reduced)
    worst_spacing_axis = np.argmax(target)
    other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
    other_spacings = [target[i] for i in other_axes]
    other_sizes = [target_size[i] for i in other_axes]

    has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * min(other_spacings))
    has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
    # we don't use the last one for now
    # median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)

    if has_aniso_spacing and has_aniso_voxels:
        spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
        # 10th percentile: resample the anisotropic axis less aggressively
        target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
        # don't let the spacing of that axis get higher than the other axes
        if target_spacing_of_that_axis < min(other_spacings):
            target_spacing_of_that_axis = max(min(other_spacings), target_spacing_of_that_axis) + 1e-5
        target[worst_spacing_axis] = target_spacing_of_that_axis
    return target
def trigger_low_res_model(
    self,
    prev_res_patch_size: Sequence[int],
    transpose_forward: Sequence[int],
) -> bool:
    """
    Trigger additional low resolution model

    Args:
        prev_res_patch_size: patch size of previous stage
        transpose_forward: axis permutation applied to the data; used to
            bring the measured object sizes into patch-size axis order

    Returns:
        bool: If True, trigger a low resolution model. If False, current
            resolution is ok.
    """
    all_boxes = [case["boxes"] for case in
                 self.data_properties["instance_props_per_patient"].values()]
    all_boxes = concatenate_property_boxes(all_boxes)

    # 99.5th percentile of object extents per axis: trigger when (nearly)
    # the largest objects do not fit into the current patch along any axis
    object_size = np.percentile(box_size_np(all_boxes), 99.5, axis=0)
    object_size = object_size[list(transpose_forward)]
    return bool((np.asarray(prev_res_patch_size) < object_size).any())
......@@ -38,6 +38,7 @@ from nndet.core.boxes.sampler import HardNegativeSamplerBatched
from nndet.core.boxes.coder import CoderType, BoxCoderND
from nndet.core.boxes.anchors import get_anchor_generator
from nndet.core.boxes.ops import box_iou
from nndet.core.boxes.anchors import AnchorGeneratorType
from nndet.ptmodule.base_module import LightningBaseModuleSWA, LightningBaseModule
......@@ -509,7 +510,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
cls,
plan_arch: dict,
model_cfg: dict,
anchor_generator: AnchorType,
anchor_generator: AnchorGeneratorType,
) -> ClassifierType:
"""
Build classification subnetwork for detection head
......@@ -543,7 +544,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
cls,
plan_arch: dict,
model_cfg: dict,
anchor_generator: AnchorType,
anchor_generator: AnchorGeneratorType,
) -> RegressorType:
"""
Build regression subnetwork for detection head
......
......@@ -202,7 +202,7 @@ def run_planning_and_process(
splitted_4d_output_dir: Path,
cropped_output_dir: Path,
preprocessed_output_dir: Path,
planners: Dict[str, Sequence[str]],
planner_name: str,
dim: int,
model_name: str,
model_cfg: Dict,
......@@ -216,8 +216,7 @@ def run_planning_and_process(
splitted_4d_output_dir: base dir of splitted data
cropped_output_dir: base dir of cropped data
preprocessed_output_dir: base dir of preprocessed data
planners: define planners for
the needed dimension
planner_name: planner name
dim: number of spatial dimensions
model_name: name of model to run planning for
model_cfg: hyperparameters of model (used during planning to
......@@ -225,51 +224,49 @@ def run_planning_and_process(
num_processes: number of processes to use for preprocessing
run_preprocessing: Preprocess and check data. Defaults to True.
"""
selected_planners = planners[f"{dim}d"]
for planner_name in selected_planners:
planner_cls = PLANNER_REGISTRY.get(planner_name)
planner = planner_cls(
preprocessed_output_dir=preprocessed_output_dir
)
plan_identifiers = planner.plan_experiment(
model_name=model_name,
model_cfg=model_cfg,
)
if run_preprocessing:
for plan_id in plan_identifiers:
plan = load_pickle(preprocessed_output_dir / plan_id)
planner_cls = PLANNER_REGISTRY.get(planner_name)
planner = planner_cls(
preprocessed_output_dir=preprocessed_output_dir
)
plan_identifiers = planner.plan_experiment(
model_name=model_name,
model_cfg=model_cfg,
)
if run_preprocessing:
for plan_id in plan_identifiers:
plan = load_pickle(preprocessed_output_dir / plan_id)
planner.run_preprocessing(
cropped_data_dir=cropped_output_dir / "imagesTr",
plan=plan,
num_processes=num_processes,
)
case_ids_failed, result_check = run_check(
data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
remove=True,
processes=num_processes
)
# delete and rerun corrupted cases
if not result_check:
logger.warning(f"{plan_id} check failed: There are corrupted files {case_ids_failed}!!!!"
f"Running preprocessing of those cases without multiprocessing.")
planner.run_preprocessing(
cropped_data_dir=cropped_output_dir / "imagesTr",
plan=plan,
num_processes=num_processes,
)
num_processes=0,
)
case_ids_failed, result_check = run_check(
data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
remove=True,
processes=num_processes
remove=False,
processes=0
)
# delete and rerun corrupted cases
if not result_check:
logger.warning(f"{plan_id} check failed: There are corrupted files {case_ids_failed}!!!!"
f"Running preprocessing of those cases without multiprocessing.")
planner.run_preprocessing(
cropped_data_dir=cropped_output_dir / "imagesTr",
plan=plan,
num_processes=0,
)
case_ids_failed, result_check = run_check(
data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
remove=False,
processes=0
)
if not result_check:
logger.error(f"Could not fix corrupted files {case_ids_failed}!")
raise RuntimeError("Found corrupted files, check logs!")
else:
logger.info("Fixed corrupted files.")
logger.error(f"Could not fix corrupted files {case_ids_failed}!")
raise RuntimeError("Found corrupted files, check logs!")
else:
logger.info(f"{plan_id} check successful: Loading check completed")
logger.info("Fixed corrupted files.")
else:
logger.info(f"{plan_id} check successful: Loading check completed")
if run_preprocessing:
create_labels(
......@@ -406,7 +403,7 @@ def run(cfg, instances_from_seg):
splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
planners=cfg["planners"],
planner_name=cfg["planner"],
dim=data_info["dim"],
model_name=cfg["module"],
model_cfg=cfg["model_cfg"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment