Commit fade0f46 authored by mibaumgartner's avatar mibaumgartner
Browse files

clean up and improved multi class import

parent 537a3032
......@@ -15,10 +15,10 @@
| Model | LIDC | RibFrac | CADA | Kits19 |
|:-----------:|:--------------------:|:--------:|:--------:|:--------:|
| nnDetection | 0.605 | 0.765 | 0.924 | 0.923 |
| nnUNetPlus | 0.385<sup>*</sup> | 0.700 | 0.955 | 0.935 |
| nnUNetBasic | 0.346<sup>*</sup> | 0.667 | 0.930 | 0.908 |
| nnUNetPlus | 0.439<sup>*</sup> | 0.700 | 0.955 | 0.935 |
| nnUNetBasic | 0.411<sup>*</sup> | 0.667 | 0.930 | 0.908 |
<sup>*</sup> results with corrected numerical values in softdice loss. Out of the box results: nnUNetPlus 0.304 and nnUNetBasic 0.232
<sup>*</sup> results with corrected numerical values in softdice loss and improved multi-class import.
&nbsp;
......@@ -27,21 +27,24 @@
5 Fold Cross Validation
| Model | ADAM | ProstateX | Pancreas | Hepatic Vessel | Colon | Liver |
|:-----------:|:--------:|:----------:|:--------:|:----------------:|:-----:|:-----:|
|:-----------:|:--------:|:----------------------:|:--------:|:----------------:|:-----:|:-----:|
| nnDetection | 0.780 | 0.300 | 0.766 | 0.727 | 0.662 | 0.628 |
| nnUNetPlus | 0.720 | 0.197 | 0.721 | 0.721 | 0.579 | 0.678 |
| nnUNetBasic | 0.657 | 0.204 | 0.691 | 0.699 | 0.509 | 0.567 |
| nnUNetPlus | 0.720 | 0.220<sup>*</sup> | 0.721 | 0.721 | 0.579 | 0.678 |
| nnUNetBasic | 0.657 | 0.202<sup>*</sup> | 0.691 | 0.699 | 0.509 | 0.567 |
<sup>*</sup>improved multi-class import
&nbsp;
Test Split
| Model | ProstateX | Pancreas | Hepatic Vessel | Colon | Liver |
|:-----------:|:----------:|:--------:|:----------------:|:-----:|:-----:|
|:-----------:|:----------------------:|:--------:|:----------------:|:-----:|:-----:|
| nnDetection | 0.221 | 0.791 | 0.664 | 0.696 | 0.790 |
| nnUNetPlus | 0.078 | 0.704 | 0.684 | 0.731 | 0.760 |
| nnUNetPlus | 0.123<sup>*</sup> | 0.704 | 0.684 | 0.731 | 0.760 |
ADAM Results are listed under Benchmarks
<sup>*</sup>improved multi-class import
&nbsp;
......
......@@ -19,18 +19,16 @@ import numpy as np
import SimpleITK as sitk
from pathlib import Path
from typing import Dict, List, Sequence, Optional
from typing import List, Sequence
from nndet.io.paths import Pathlike
from loguru import logger
from sklearn.model_selection import train_test_split
from nndet.io.paths import Pathlike
from nndet.io.paths import get_case_ids_from_dir
from nndet.io.load import save_json
from nndet.utils.clustering import seg2instances, remove_classes, reorder_classes
__all__ = ["maybe_split_4d_nifti", "instances_from_segmentation", "sitk_copy_metadata"]
__all__ = ["maybe_split_4d_nifti"]
def maybe_split_4d_nifti(source_file: Path, output_folder: Path):
......@@ -114,99 +112,6 @@ def create_itk_image_spatial_props(
return data_itk
def sitk_copy_metadata(img_source: sitk.Image, img_target: sitk.Image) -> sitk.Image:
    """
    Deprecated: copy metadata (spacing, origin, direction) from a source to
    a target image. Set the three properties directly on the target image
    instead (SetSpacing / SetOrigin / SetDirection).

    Args:
        img_source: source image
        img_target: target image

    Raises:
        RuntimeError: always; this helper is deprecated.
    """
    # The original body below the raise (copying spacing, origin and
    # direction) was unreachable dead code and has been removed; the
    # function unconditionally signals its deprecation.
    raise RuntimeError("Deprecated")
def instances_from_segmentation(source_file: Path, output_folder: Path,
                                rm_classes: Sequence[int] = None,
                                ro_classes: Dict[int, int] = None,
                                subtract_one_of_classes: bool = True,
                                fg_vs_bg: bool = False,
                                file_name: Optional[str] = None
                                ):
    """
    Convert a semantic segmentation file into an instance segmentation.

    1. Optionally removes classes from the segmentation
        (e.g. organ segmentations which are not useful for detection)
    2. Optionally reorders the segmentation indices
    3. Converts the semantic segmentation to an instance segmentation via
        connected components

    Args:
        source_file: path to semantic segmentation file
        output_folder: folder where processed file will be saved
        rm_classes: classes to remove from semantic segmentation
        ro_classes: reorder classes before instances are generated
        subtract_one_of_classes: subtracts one from the classes
            in the instance mapping (detection networks assume
            that classes start from 0)
        fg_vs_bg: map all foreground classes to a single class to run
            foreground vs background detection task.
        file_name: name of saved file (without file type!)
    """
    if subtract_one_of_classes and fg_vs_bg:
        logger.info("subtract_one_of_classes will be ignored because fg_vs_bg is "
                    "active and all foreground classes will be mapped to 0")

    seg_itk = sitk.ReadImage(str(source_file))
    seg_npy = sitk.GetArrayFromImage(seg_itk)

    if rm_classes is not None:
        seg_npy = remove_classes(seg_npy, rm_classes)
    if ro_classes is not None:
        seg_npy = reorder_classes(seg_npy, ro_classes)

    instances, instance_classes = seg2instances(seg_npy)
    if fg_vs_bg:
        num_instances_before = len(instance_classes)
        seg_npy[seg_npy > 0] = 1
        instances, instance_classes = seg2instances(seg_npy)
        num_instances_after = len(instance_classes)
        if num_instances_after != num_instances_before:
            # BUG FIX: the before/after counts were swapped in the message
            logger.warning(
                f"Lost instance: Found {num_instances_before} instances before "
                f"fg_vs_bg but {num_instances_after} instances after it")

    if subtract_one_of_classes:
        for key in instance_classes.keys():
            instance_classes[key] -= 1
    if fg_vs_bg:
        for key in instance_classes.keys():
            instance_classes[key] = 0

    seg_itk_new = sitk.GetImageFromArray(instances)
    # BUG FIX: sitk_copy_metadata unconditionally raises
    # RuntimeError("Deprecated"), so calling it made this function crash
    # before writing any output. Copy the metadata inline instead.
    seg_itk_new.SetSpacing(seg_itk.GetSpacing())
    seg_itk_new.SetOrigin(seg_itk.GetOrigin())
    seg_itk_new.SetDirection(seg_itk.GetDirection())

    if file_name is None:
        # strip the full multi-part suffix, e.g. ".nii.gz"
        suffix_length = sum(map(len, source_file.suffixes))
        file_name = source_file.name[:-suffix_length]
    # cast to builtin ints: numpy integers are not JSON serializable
    save_json({"instances": {int(k): int(v) for k, v in instance_classes.items()}},
              output_folder / f"{file_name}.json")
    sitk.WriteImage(seg_itk_new, str(output_folder / f"{file_name}.nii.gz"))
def create_test_split(splitted_dir: Pathlike,
num_modalities: int,
test_size: float = 0.3,
......
......@@ -22,17 +22,16 @@ from typing import Dict, Sequence, Union, Tuple, Optional
from nndet.io.transforms.instances import get_bbox_np
def seg_to_instances(
    seg: np.ndarray,
    min_num_voxel: int = 0,
) -> Tuple[np.ndarray, Dict[int, int]]:
    """
    Create instances from a semantic segmentation by running connected
    components (full connectivity, including diagonals) independently
    for every foreground class.

    Args:
        seg: semantic segmentation [spatial dims]
        min_num_voxel: minimum number of voxels of an instance

    Returns:
        np.ndarray: instance segmentation (ids are consecutive, starting at 1)
        Dict[int, int]: mapping from instance ids to semantic classes
    """
    # ones matrix -> full connectivity (diagonal neighbours merge too)
    structure = np.ones([3] * seg.ndim)

    unique_classes = np.unique(seg)
    unique_classes = unique_classes[unique_classes > 0]  # skip background

    instances = np.zeros_like(seg)
    instance_classes = {}
    i = 1
    for uc in unique_classes:
        binary_class_mask = (seg == uc)
        instances_temp, _ = label(binary_class_mask, structure=structure)
        instance_ids = np.unique(instances_temp)
        instance_ids = instance_ids[instance_ids > 0]

        for iid in instance_ids:
            instance_binary_mask = instances_temp == iid
            if min_num_voxel > 0 and instance_binary_mask.sum() < min_num_voxel:
                continue  # remove small instances
            instances[instance_binary_mask] = i  # save instance to final mask
            # cast to builtin int: numpy integers are not JSON serializable
            # when this mapping is dumped via save_json downstream
            instance_classes[int(i)] = int(uc)
            i = i + 1  # bump instance index
    return instances, instance_classes
def seg_to_instances_voted(
    seg: np.ndarray,
    min_num_voxel: int = 0,
) -> Tuple[np.ndarray, Dict[int, int]]:
    """
    Connected component analysis is performed on the whole foreground
    (independent of exact class) and the final class of each component
    is determined via majority voting over its voxels.
    Ties resolve to the smallest class id (np.argmax returns the first
    maximum).

    Args:
        seg: semantic segmentation [spatial dims]
        min_num_voxel: minimum number of voxels of an instance

    Returns:
        np.ndarray: instance segmentation (ids are consecutive, starting at 1)
        Dict[int, int]: mapping from instance ids to semantic classes
    """
    # ones matrix -> full connectivity (diagonal neighbours merge too)
    structure = np.ones([3] * seg.ndim)
    binary_fg_mask = (seg > 0).astype(int)
    instances_temp, _ = label(binary_fg_mask, structure=structure)
    instance_ids = np.unique(instances_temp)
    instance_ids = instance_ids[instance_ids > 0]  # skip background

    instance_classes = {}
    instances = np.zeros_like(instances_temp)
    i = 1
    for iid in instance_ids:
        instance_binary_mask = instances_temp == iid
        if min_num_voxel > 0 and instance_binary_mask.sum() < min_num_voxel:
            continue  # remove small instances
        instances[instance_binary_mask] = i  # save instance to final mask
        cls_id, cls_count = np.unique(
            seg[instance_binary_mask], return_counts=True)  # count classes in region
        majority_voted_class = cls_id[np.argmax(cls_count)]  # select class with most votes
        # components were built from the foreground mask, so background
        # must never appear inside a component
        assert 0 not in cls_id
        assert majority_voted_class > 0
        # cast to builtin int: numpy integers are not JSON serializable
        instance_classes[int(i)] = int(majority_voted_class)
        i = i + 1  # bump instance index
    return instances, instance_classes
def remove_classes(seg: np.ndarray, rm_classes: Sequence[int], classes: Dict[int, int] = None,
background: int = 0) -> Union[np.ndarray, Tuple[np.ndarray, Dict[int, int]]]:
def remove_classes(
seg: np.ndarray,
rm_classes: Sequence[int],
classes: Dict[int, int] = None,
background: int = 0,
) -> Union[np.ndarray, Tuple[np.ndarray, Dict[int, int]]]:
"""
Remove classes from segmentation (also works on instances
but instance ids may not be consecutive anymore)
......@@ -90,7 +146,10 @@ def remove_classes(seg: np.ndarray, rm_classes: Sequence[int], classes: Dict[int
return seg, classes
def reorder_classes(seg: np.ndarray, class_mapping: Dict[int, int]) -> np.ndarray:
def reorder_classes(
seg: np.ndarray,
class_mapping: Dict[int, int],
) -> np.ndarray:
"""
Reorders classes in segmentation
......@@ -106,7 +165,8 @@ def reorder_classes(seg: np.ndarray, class_mapping: Dict[int, int]) -> np.ndarra
return seg
def compute_score_from_seg(instances: np.ndarray,
def compute_score_from_seg(
instances: np.ndarray,
instance_classes: Dict[int, int],
probs: np.ndarray,
aggregation: str = "max",
......@@ -148,7 +208,8 @@ def compute_score_from_seg(instances: np.ndarray,
return np.asarray(instance_scores)
def instance_results_from_seg(probs: np.ndarray,
def softmax_to_instances(
probs: np.ndarray,
aggregation: str,
stuff: Optional[Sequence[int]] = None,
min_num_voxel: int = 0,
......@@ -178,12 +239,16 @@ def instance_results_from_seg(probs: np.ndarray,
`pred_labels`: predicted class for each instance/box
`pred_scores`: predicted score for each instance/box
"""
if probs.shape[0] < 2:
raise ValueError("softmax_to_instances only works for softmax probabilities")
if min_threshold is not None:
if probs.shape[0] > 2:
fg_argmax = np.argmax(probs, axis=0)
fg_mask = np.max(probs[1:], axis=0) > min_threshold
cluster_map = np.max(probs[1:], axis=0) > min_threshold
class_map = np.argmax(probs[1:], axis=0) + 1
seg = np.zeros_like(probs[0])
seg[fg_mask] = fg_argmax[fg_mask]
seg[cluster_map] = class_map[cluster_map]
else:
seg = probs[1] > min_threshold
else:
......@@ -192,12 +257,12 @@ def instance_results_from_seg(probs: np.ndarray,
if stuff is not None:
for s in stuff:
seg[seg == s] = 0
instances, instance_classes = seg2instances(seg,
exclude_background=True,
min_num_voxel=min_num_voxel,
instances, instance_classes = seg_to_instances_voted(seg, min_num_voxel=min_num_voxel)
instance_scores = compute_score_from_seg(
instances, instance_classes, probs, aggregation=aggregation,
)
instance_scores = compute_score_from_seg(instances, instance_classes, probs,
aggregation=aggregation)
instance_classes = {int(key): int(item) - 1 for key, item in instance_classes.items()}
tmp = get_bbox_np(instances[None], instance_classes)
instance_boxes = tmp["boxes"]
......
......@@ -28,7 +28,6 @@ from typing import Sequence, Union
from nndet.io.itk import load_sitk_as_array, load_sitk
from nndet.io.load import save_json, load_json
from nndet.io.paths import get_case_ids_from_dir
from nndet.io.prepare import sitk_copy_metadata
from nndet.io.transforms.instances import instances_to_segmentation_np
Pathlike = Union[str, Path]
......
import os
import shutil
from typing import Sequence, Dict, Optional
import SimpleITK as sitk
from pathlib import Path
from loguru import logger
from nndet.io import save_json
from nndet.io.prepare import instances_from_segmentation
from nndet.utils.check import env_guard
from nndet.utils.info import maybe_verbose_iterable
from nndet.utils.clustering import seg_to_instances, remove_classes, reorder_classes
def instances_from_segmentation(
    source_file: Path, output_folder: Path,
    rm_classes: Sequence[int] = None,
    ro_classes: Dict[int, int] = None,
    subtract_one_of_classes: bool = True,
    fg_vs_bg: bool = False,
    file_name: Optional[str] = None
):
    """
    Convert a semantic segmentation file into an instance segmentation.

    1. Optionally removes classes from the segmentation
        (e.g. organ segmentations which are not useful for detection)
    2. Optionally reorders the segmentation indices
    3. Converts the semantic segmentation to an instance segmentation via
        connected components

    Args:
        source_file: path to semantic segmentation file
        output_folder: folder where processed file will be saved
        rm_classes: classes to remove from semantic segmentation
        ro_classes: reorder classes before instances are generated
        subtract_one_of_classes: subtracts one from the classes
            in the instance mapping (detection networks assume
            that classes start from 0)
        fg_vs_bg: map all foreground classes to a single class to run
            foreground vs background detection task.
        file_name: name of saved file (without file type!)
    """
    if subtract_one_of_classes and fg_vs_bg:
        logger.info("subtract_one_of_classes will be ignored because fg_vs_bg is "
                    "active and all foreground classes will be mapped to 0")

    seg_itk = sitk.ReadImage(str(source_file))
    seg_npy = sitk.GetArrayFromImage(seg_itk)

    if rm_classes is not None:
        seg_npy = remove_classes(seg_npy, rm_classes)
    if ro_classes is not None:
        seg_npy = reorder_classes(seg_npy, ro_classes)

    instances, instance_classes = seg_to_instances(seg_npy)
    if fg_vs_bg:
        num_instances_before = len(instance_classes)
        seg_npy[seg_npy > 0] = 1
        instances, instance_classes = seg_to_instances(seg_npy)
        num_instances_after = len(instance_classes)
        if num_instances_after != num_instances_before:
            # BUG FIX: the before/after counts were swapped in the message
            logger.warning(
                f"Lost instance: Found {num_instances_before} instances before "
                f"fg_vs_bg but {num_instances_after} instances after it")

    if subtract_one_of_classes:
        for key in instance_classes.keys():
            instance_classes[key] -= 1
    if fg_vs_bg:
        for key in instance_classes.keys():
            instance_classes[key] = 0

    # carry over the spatial metadata of the source image
    seg_itk_new = sitk.GetImageFromArray(instances)
    seg_itk_new.SetSpacing(seg_itk.GetSpacing())
    seg_itk_new.SetOrigin(seg_itk.GetOrigin())
    seg_itk_new.SetDirection(seg_itk.GetDirection())

    if file_name is None:
        # strip the full multi-part suffix, e.g. ".nii.gz"
        suffix_length = sum(map(len, source_file.suffixes))
        file_name = source_file.name[:-suffix_length]
    # cast to builtin ints: numpy integers are not JSON serializable
    save_json({"instances": {int(k): int(v) for k, v in instance_classes.items()}},
              output_folder / f"{file_name}.json")
    sitk.WriteImage(seg_itk_new, str(output_folder / f"{file_name}.nii.gz"))
def run_prep_fg_v_bg(
......
......@@ -26,16 +26,15 @@ import numpy as np
import SimpleITK as sitk
from hydra import initialize_config_module
from loguru import logger
from scipy import ndimage
from scipy.ndimage import label
from tqdm import tqdm
from nndet.core.boxes import box_size_np
from nndet.io import get_case_ids_from_dir, load_json, save_json
from nndet.io import save_json
from nndet.io.transforms.instances import get_bbox_np
from nndet.io.itk import copy_meta_data_itk, load_sitk, load_sitk_as_array
from nndet.io.itk import load_sitk, load_sitk_as_array
from nndet.utils.config import compose
from nndet.utils.check import env_guard
from nndet.utils.clustering import seg_to_instances
def prepare_detection_label(case_id: str,
......@@ -66,11 +65,10 @@ def prepare_detection_label(case_id: str,
sitk.WriteImage(stuff_seg_itk, str(label_dir / f"{case_id}_stuff.nii.gz"))
# prepare things information
structure = np.ones([3] * seg.ndim)
things_seg = np.copy(seg)
things_seg[stuff_seg > 0] = 0 # remove all stuff classes from segmentation
instances_not_filtered, _ = label(things_seg, structure=structure)
instances_not_filtered, instances_not_filtered_classes = seg_to_instances(things_seg)
final_mapping = {}
if instances_not_filtered.max() > 0:
boxes = get_bbox_np(instances_not_filtered[None])["boxes"]
......@@ -92,11 +90,8 @@ def prepare_detection_label(case_id: str,
if all(bsize_world[isotopic_axis] > min_size) and (instance_vol > min_vol):
instances[instance_mask] = start_id
single_idx = np.argwhere(instance_mask)[0]
semantic_class = int(seg[tuple(single_idx)])
semantic_class = instances_not_filtered_classes[int(iid)]
final_mapping[start_id] = things_classes.index(semantic_class)
start_id += 1
else:
instances = np.zeros_like(instances_not_filtered)
......@@ -128,9 +123,12 @@ def main():
============================================================================
Needs additional information from dataset.json/.yaml:
`seg2det_stuff`: these are classes which are interpreted semantically
(stuff classes are experimental and will probably changed in
the future)
`seg2det_things`: these are classes which are interpreted as instances
Both entries should be lists with the indices of the respective
classes where the position will determine its new class
(currently only one class is supported here)
e.g.
`seg2det_stuff`: [2,] -> remap class 2 from semantic segmentation
to new stuff class 1 (stuff classes start at one)
......
......@@ -33,7 +33,7 @@ from loguru import logger
from nndet.evaluator.registry import evaluate_box_dir
from nndet.io import load_pickle, save_pickle, get_task, load_json
from nndet.utils.clustering import instance_results_from_seg
from nndet.utils.clustering import softmax_to_instances
from nndet.utils.config import compose
from nndet.utils.info import maybe_verbose_iterable
......@@ -194,10 +194,12 @@ def import_dir(
stuff=stuff,
)
# for s in maybe_verbose_iterable(source):
# _fn(s, target_dir)
if num_workers > 0:
with Pool(processes=num_workers) as p:
p.starmap(_fn, zip(source, repeat(target_dir)))
else:
for s in maybe_verbose_iterable(source):
_fn(s, target_dir)
def import_single_case(logits_source: Path,
......@@ -240,7 +242,7 @@ def import_single_case(logits_source: Path,
tmp[:, bbox[0][0]:bbox[0][1], bbox[1][0]:bbox[1][1], bbox[2][0]:bbox[2][1]] = probs
probs = tmp
res = instance_results_from_seg(probs,
res = softmax_to_instances(probs,
aggregation=aggregation,
min_num_voxel=min_num_voxel,
min_threshold=min_threshold,
......
......@@ -38,80 +38,11 @@ from nndet.planning import PLANNER_REGISTRY
from nndet.planning.experiment.utils import create_labels
from nndet.planning.properties.registry import medical_instance_props
from nndet.io.load import load_pickle, load_npz_looped
from nndet.io.prepare import maybe_split_4d_nifti, instances_from_segmentation
from nndet.io.paths import get_paths_raw_to_split, get_paths_from_splitted_dir, subfiles, get_case_id_from_path
from nndet.preprocessing import ImageCropper
from nndet.utils.check import check_dataset_file, check_data_and_label_splitted
def run_splitting_4d(data_dir: Path, output_dir: Path, num_processes: int) -> None:
    """
    Due to historical reasons this framework uses 3D niftis instead of 4D niftis.
    This function splits present 4D niftis into 3D niftis per channel.

    Note: any existing content of `output_dir` is deleted before splitting.

    Args:
        data_dir (Path): top directory where data is located
        output_dir (Path): output directory for splitted data
        num_processes (int): number of processes to use to split data
    """
    # start from a clean output directory so stale files from a previous
    # run cannot leak into the new split
    if output_dir.is_dir():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    source_files, target_folders = get_paths_raw_to_split(data_dir, output_dir)
    # split files in parallel; maybe_split_4d_nifti is a no-op for 3D inputs
    with Pool(processes=num_processes) as p:
        p.starmap(maybe_split_4d_nifti, zip(source_files, target_folders))
def prepare_labels(data_dir: Path,
                   output_dir: Path,
                   num_processes: int,
                   rm_classes: Sequence[int],
                   ro_classes: Dict[int, int],
                   subtract_one_from_classes: bool,
                   instances_from_seg: bool = True):
    """
    Copy labels to splitted dir.
    Optionally, runs connected components and removes classes from
    semantic segmentations

    Args:
        data_dir: path to task base dir
        output_dir: base dir of prepared labels
        num_processes: number of processes to use
        rm_classes: remove specific classes from semantic segmentation.
            Can only be used with `instances_from_seg`
        ro_classes: reorder classes in semantic segmentation for
            connected components. Can only be used with `instances_from_seg`
        subtract_one_from_classes: class indices for detection start from 0.
            Subtracts 1 from classes extracted from segmentation
        instances_from_seg: Run connected components. Defaults to True.
    """
    for subdir in ("labelsTr", "labelsTs"):
        source_dir = data_dir / subdir
        if not source_dir.is_dir():
            # test labels are optional; skip missing subdirs
            continue
        target_dir = output_dir / subdir

        if not instances_from_seg:
            # plain copy, no connected-component processing
            shutil.copytree(source_dir, target_dir)
            continue

        if not target_dir.is_dir():
            target_dir.mkdir(parents=True)
        # collect nifti labels, ignoring hidden files (e.g. macOS ._ files)
        nifti_paths = [
            p for p in map(Path, subfiles(source_dir, identifier="*.nii.gz", join=True))
            if not p.name.startswith('.')
        ]
        with Pool(processes=num_processes) as pool:
            pool.starmap(instances_from_segmentation, zip(
                nifti_paths, repeat(target_dir), repeat(rm_classes),
                repeat(ro_classes), repeat(subtract_one_from_classes)))
def run_cropping_and_convert(cropped_output_dir: Path,
splitted_4d_output_dir: Path,
data_info: dict,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment