Commit 7246044d authored by mibaumgartner

Merge remote-tracking branch 'origin/master' into main

parents fcec502f 6f4c3333
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
from typing import List, Union, Sequence
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from loguru import logger
def linear_warm_up(
iteration: int,
initial_lr: float,
num_iterations: int,
final_lr: float,
) -> float:
"""
Linear learning rate warm up
Args:
iteration: current iteration
initial_lr: initial learning rate for poly lr
num_iterations: total number of iterations of warmup
final_lr: final learning rate of warmup
Returns:
float: learning rate
"""
assert final_lr > initial_lr
if iteration >= num_iterations:
logger.warning(f"WarmUp was stepped too often, {iteration} "
f"but only {num_iterations} were expected!")
return initial_lr + (final_lr - initial_lr) * (float(iteration) / float(num_iterations))
def poly_lr(
iteration: int,
initial_lr: float,
num_iterations: int,
gamma: float,
) -> float:
"""
initial_lr * (1 - epoch / max_epochs) ** gamma
Adapted from
https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/learning_rate/poly_lr.py
https://arxiv.org/abs/1904.08128
Args:
iteration: current iteration
initial_lr: initial learning rate for poly lr
num_iterations: total number of iterations of poly lr
gamma: gamma value
Returns:
float: learning rate
"""
if iteration >= num_iterations:
logger.warning(f"PolyLR was stepped too often, {iteration} "
f"but only {num_iterations} were expected! "
f"Using {num_iterations - 1} for lr computation.")
iteration = num_iterations - 1
return initial_lr * (1 - iteration / float(num_iterations)) ** gamma
def cyclic_linear_lr(
iteration: int,
num_iterations_cycle: int,
initial_lr: float,
final_lr: float,
) -> float:
"""
Linearly cycle learning rate
Args:
iteration: current iteration
num_iterations_cycle: number of iterations per cycle
initial_lr: learning rate to start cycle
final_lr: learning rate to end cycle
Returns:
float: learning rate
"""
cycle_iteration = int(iteration) % num_iterations_cycle
lr_multiplier = 1 - (cycle_iteration / float(num_iterations_cycle))
return initial_lr + (final_lr - initial_lr) * lr_multiplier
def cosine_annealing_lr(
iteration: int,
num_iterations: int,
initial_lr: float,
final_lr: float,
):
"""
Cosine annealing NO restarts
Args:
iteration: current iteration
num_iterations: total number of iterations of cosine lr
initial_lr: learning rate to start
final_lr: learning rate to end
Returns:
float: learning rate
"""
return final_lr + 0.5 * (initial_lr - final_lr) * (1 + \
math.cos(math.pi * float(iteration) / float(num_iterations)))
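# The helper below is a usage sketch and not part of the original module: it shows
# how the functional pieces above compose into the warmup + poly decay schedule
# implemented by LinearWarmupPolyLR further down. All numeric values are illustrative.
def _demo_warmup_poly_schedule(
        total_iterations: int = 1000,
        warm_iterations: int = 50,
        warm_lr: float = 1e-6,
        base_lr: float = 1e-2,
        gamma: float = 0.9,
        ) -> List[float]:
    lrs = []
    for it in range(total_iterations):
        if it < warm_iterations:
            # linear ramp from warm_lr to base_lr
            lrs.append(linear_warm_up(it, initial_lr=warm_lr,
                                      num_iterations=warm_iterations, final_lr=base_lr))
        else:
            # polynomial decay from base_lr towards zero
            lrs.append(poly_lr(it - warm_iterations, initial_lr=base_lr,
                               num_iterations=total_iterations - warm_iterations,
                               gamma=gamma))
    return lrs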
class LinearWarmupPolyLR(_LRScheduler):
def __init__(self,
optimizer: Optimizer,
warm_iterations: int,
warm_lr: Union[float, Sequence[float]],
poly_gamma: float,
num_iterations: int,
last_epoch: int = -1,
) -> None:
"""
Linear Warm Up LR -> Poly LR
Args:
optimizer: optimizer for lr scheduling
warm_iterations: number of warmup iterations
warm_lr: initial learning rate of warm up
poly_gamma: gamma of poly lr
num_iterations: total number of iterations (including warmup)
last_epoch: The index of the last epoch. Defaults to -1.
"""
self.num_iterations = num_iterations
# warmup
self.warm_iterations = warm_iterations
if not isinstance(warm_lr, list) and not isinstance(warm_lr, tuple):
self.warm_lr = [warm_lr] * len(optimizer.param_groups)
else:
if len(warm_lr) != len(optimizer.param_groups):
raise ValueError("Expected {} warm_lr, but got {}".format(
len(optimizer.param_groups), len(warm_lr)))
self.warm_lr = list(warm_lr)
# poly lr
self.poly_iterations = self.num_iterations - self.warm_iterations
self.poly_gamma = poly_gamma
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self) -> List[float]:
"""
Compute current learning rate for each param group
"""
if self.last_epoch < self.warm_iterations:
# warm up period
lrs = [linear_warm_up(
iteration=self._step_count,
initial_lr=self.warm_lr[idx],
num_iterations=self.warm_iterations,
final_lr=base_lr,
) for idx, base_lr in enumerate(self.base_lrs)]
else:
# poly lr phase
lrs = [poly_lr(
iteration=self._step_count - self.warm_iterations,
initial_lr=base_lr,
num_iterations=self.poly_iterations,
gamma=self.poly_gamma,
) for idx, base_lr in enumerate(self.base_lrs)]
return lrs
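# Usage sketch (assumption, standard torch scheduler loop; not part of the original
# module): step the scheduler once per training iteration and record the lr of the
# first param group.
def _demo_linear_warmup_poly(optimizer: Optimizer, num_iterations: int = 1000) -> List[float]:
    scheduler = LinearWarmupPolyLR(
        optimizer,
        warm_iterations=50,
        warm_lr=1e-6,
        poly_gamma=0.9,
        num_iterations=num_iterations,
    )
    observed = []
    for _ in range(num_iterations):
        # forward/backward/optimizer.step() would go here
        scheduler.step()
        observed.append(scheduler.get_last_lr()[0])
    return observed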
class CycleLinear(_LRScheduler):
def __init__(self,
optimizer: Optimizer,
cycle_num_iterations: int,
cycle_initial_lr: Union[float, Sequence[float]],
cycle_final_lr:Union[float, Sequence[float]],
last_epoch: int = -1,
) -> None:
"""
Cyclic learning rates with linear decay
Args:
optimizer: optimizer for lr scheduling
cycle_num_iterations: number of iterations per cycle
cycle_initial_lr: initial learning rate of cycle
cycle_final_lr: final learning rate of cycle
last_epoch: The index of the last epoch. Defaults to -1.
"""
# cycle linear lr
self.cycle_num_iterations = cycle_num_iterations
if not isinstance(cycle_initial_lr, list) and not isinstance(cycle_initial_lr, tuple):
self.cycle_initial_lr = [cycle_initial_lr] * len(optimizer.param_groups)
else:
if len(cycle_initial_lr) != len(optimizer.param_groups):
raise ValueError("Expected {} cycle_initial_lr, but got {}".format(
len(optimizer.param_groups), len(cycle_initial_lr)))
self.cycle_initial_lr = list(cycle_initial_lr)
if not isinstance(cycle_final_lr, list) and not isinstance(cycle_final_lr, tuple):
self.cycle_final_lr = [cycle_final_lr] * len(optimizer.param_groups)
else:
if len(cycle_final_lr) != len(optimizer.param_groups):
raise ValueError("Expected {} cycle_final_lr, but got {}".format(
len(optimizer.param_groups), len(cycle_final_lr)))
self.cycle_final_lr = list(cycle_final_lr)
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self) -> List[float]:
"""
Compute current learning rate for each param group
"""
lrs = [cyclic_linear_lr(
iteration=max(self._step_count - 1, 0), # init steps once
num_iterations_cycle=self.cycle_num_iterations,
initial_lr=self.cycle_initial_lr[idx],
final_lr=self.cycle_final_lr[idx],
) for idx, base_lr in enumerate(self.base_lrs)]
return lrs
class WarmUpExponential(_LRScheduler):
def __init__(self,
optimizer: Optimizer,
beta2: float,
last_epoch: int = -1,
):
"""
Exponential learning rate warmup
warmup_lr = base_lr * (1 - exp(-(1 - beta2) * t))
for 2 * (1 - beta2)^(-1) iterations
`On the adequacy of untuned warmup for adaptive optimization`
https://arxiv.org/abs/1910.04209
Args:
optimizer: optimizer to schedule lr from (best used with Adam,
AdamW)
beta2: second beta param of Adam optimizer.
last_epoch: The index of the last epoch. Defaults to -1.
"""
self.iterations = int(2. * (1. / (1. - beta2)))
self.beta2 = beta2
logger.info(f"Running exponential warmup for {self.iterations} iterations")
self.finished = False
super().__init__(optimizer=optimizer, last_epoch=last_epoch)
def get_lr(self) -> List[float]:
"""
Compute current learning rate for each param group
"""
# last epoch is automatically handled by parent class
return [base_lr * (1 - math.exp(- (1 - self.beta2) * self.last_epoch))
for base_lr in self.base_lrs]
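# Quick numeric sketch (illustrative, not part of the original module): the untuned
# exponential warmup above runs for 2 / (1 - beta2) iterations, e.g. beta2 = 0.999
# yields 2000 iterations; the multiplicative factor applied to the base lr at
# iteration t is 1 - exp(-(1 - beta2) * t).
def _demo_warmup_exponential_factor(beta2: float = 0.999, iteration: int = 100) -> float:
    return 1 - math.exp(-(1 - beta2) * iteration)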
from nndet.training.optimizer.utils import (
get_params_no_wd_on_norm, identify_parameters, change_output_layer,
freeze_layers, unfreeze_layers,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Dict, Sequence
import torch
import torch.nn as nn
import nndet.arch.layers.norm as an
NORM_TYPES = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
nn.LayerNorm, nn.GroupNorm, nn.SyncBatchNorm, nn.LocalResponseNorm,
an.GroupNorm,
]
def get_params_no_wd_on_norm(model: torch.nn.Module, weight_decay: float):
"""
Apply weight decay to model but skip normalization layers
Args:
model (torch.nn.Module) : module for parameters
weight_decay (float) : weight decay for other parameters
Returns:
dict: dict with params and weight decay
See Also:
https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/2
"""
identify_parameters(model, {"no_wd": NORM_TYPES})
return [
{'params': filter(lambda p: not hasattr(p, "no_wd"), model.parameters()), 'weight_decay': weight_decay},
{'params': filter(lambda p: hasattr(p, "no_wd"), model.parameters()), 'weight_decay': 0.},
]
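# Usage sketch (assumption, standard torch API; not part of the original module):
# build an optimizer where normalization parameters are excluded from weight decay.
def _demo_no_wd_optimizer(model: torch.nn.Module,
                          lr: float = 1e-2,
                          weight_decay: float = 3e-5) -> torch.optim.SGD:
    param_groups = get_params_no_wd_on_norm(model, weight_decay=weight_decay)
    return torch.optim.SGD(param_groups, lr=lr, momentum=0.9, nesterov=True)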
def identify_parameters(model: torch.nn.Module,
type_mapping: Dict[str, Sequence],
check_param_exist: bool = True):
"""
Add attribute to searched module types (can be used to filter for specific modules in parameter list)
Args:
model: module to add attributes to
type_mapping: items specify types of modules to search, key specifies name of attribute
check_param_exist: check if module already has attribute. Can be used to assure that
attributes are not overwritten, but can lead to wrong results for shared parameters and
non "primitive" types
"""
for module in model.modules():
for _name, _types in type_mapping.items():
if any([isinstance(module, _type) for _type in _types]):
for param in module.parameters():
if check_param_exist:
assert not hasattr(param, _name)
setattr(param, _name, True)
def change_output_layer(model: torch.nn.Module, layer_name: str = "fc",
output_channels: int = 2, layer_type=torch.nn.Linear,
**kwargs) -> None:
"""
Change layer of module
Args:
model (torch.nn.Module): module where layer should be exchanged
layer_name (str): name of layer to exchange
output_channels (int): number of new output channels
layer_type (class): class of new layer
**kwargs: keyword arguments passed to constructor of new layer
"""
if not hasattr(model, layer_name):
raise ValueError(f"Model does not have layer {layer_name}.")
old_layer = getattr(model, layer_name)
input_channels = old_layer.in_features
setattr(model, layer_name,
layer_type(input_channels, output_channels, **kwargs))
def freeze_layers(model: torch.nn.Module) -> None:
"""
Freeze layers
Use something like "Optim([p for p in self.parameters() if p.requires_grad])"
to be sure.
Args:
model(torch.nn.Module): module to freeze
"""
for param in model.parameters():
param.requires_grad = False
def unfreeze_layers(model: torch.nn.Module) -> None:
"""
Unfreeze layers
Use something like "Optim([p for p in self.parameters() if p.requires_grad])"
to be sure.
Args:
model(torch.nn.Module): module to unfreeze
"""
for param in model.parameters():
param.requires_grad = True
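# Fine-tuning sketch (illustrative assumption, not part of the original module):
# freeze a backbone, swap its classification layer (assumed to be called "fc")
# for a new head and optimize only the remaining trainable parameters.
def _demo_finetune_head(model: torch.nn.Module,
                        num_classes: int = 2,
                        lr: float = 1e-3) -> torch.optim.Adam:
    freeze_layers(model)
    change_output_layer(model, layer_name="fc", output_channels=num_classes)
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.Adam(trainable, lr=lr)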
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import abstractmethod
from typing import Optional, Union, Callable
from loguru import logger
import torch
from torch.optim.lr_scheduler import _LRScheduler
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_warn
from nndet.training.learning_rate import CycleLinear
_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
class BaseSWA(StochasticWeightAveraging):
def __init__(
self,
swa_epoch_start: int,
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = torch.device("cpu"),
update_statistics: Optional[bool] = False,
):
"""
Base class for Stochastic Weight Averaging (SWA)
Args:
swa_epoch_start: Epoch to start SWA weight saving.
avg_fn: Function to average saved weights. Defaults to None.
device: Device to save averaged model. Defaults to
torch.device("cpu").
update_statistics: Perform a final update of the normalization
layers. Defaults to False.
Notes: Does not support updating of norm weights after training
"""
super().__init__(
swa_epoch_start=swa_epoch_start,
swa_lrs=None,
annealing_epochs=10,
annealing_strategy="cos",
avg_fn=avg_fn,
device=device,
)
self.update_statistics = update_statistics
logger.info(f"Initialize SWA with swa epoch start {self.swa_start}")
def pl_module_contains_batch_norm(self, pl_module: 'pl.LightningModule'):
if self.update_statistics:
raise NotImplementedError("Updating the statistics of the "
"normalization layer is not supported yet.")
else:
return self.update_statistics
def on_train_epoch_start(self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
"""
Replace the current lr scheduler with the SWA scheduler
"""
if trainer.current_epoch == self.swa_start:
optimizer = trainer.optimizers[0]
# move average model to request device.
self._average_model = self._average_model.to(self._device or pl_module.device)
_scheduler = self.get_swa_scheduler(optimizer)
self._swa_scheduler = _get_default_scheduler_config()
if not isinstance(_scheduler, dict):
_scheduler = {"scheduler": _scheduler}
self._swa_scheduler.update(_scheduler)
if trainer.lr_schedulers:
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0] = self._swa_scheduler
else:
trainer.lr_schedulers.append(self._swa_scheduler)
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
if self.swa_start <= trainer.current_epoch <= self.swa_end:
self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)
if trainer.current_epoch == self.swa_end + 1:
raise NotImplementedError("This should never happen (yet)")
@abstractmethod
def get_swa_scheduler(self, optimizer) -> Union[_LRScheduler, dict]:
"""
Generate LR scheduler for SWA
Args:
optimizer: optimizer to wrap
Returns:
Union[_LRScheduler, dict]: If a lr scheduler is returned it will
be stepped once per epoch. Can also return a whole config of
the scheduler to customize steps.
"""
raise NotImplementedError
class SWACycleLinear(BaseSWA):
def __init__(self,
swa_epoch_start: int,
cycle_initial_lr: float,
cycle_final_lr: float,
num_iterations_per_epoch: int,
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = torch.device("cpu"),
update_statistics: Optional[bool] = None,
):
"""
SWA based on :class:`CycleLinear`
Args:
swa_epoch_start: Epoch to start SWA weight saving.
cycle_initial_lr: initial learning rate of cycle
cycle_final_lr: final learning rate of cycle
num_iterations_per_epoch: number of train iterations per epoch
avg_fn: Function to average saved weights. Defaults to None.
device: Device to save averaged model. Defaults to
torch.device("cpu").
update_statistics: Perform a final update of the normalization
layers. Defaults to None.
"""
super().__init__(
swa_epoch_start=swa_epoch_start,
avg_fn=avg_fn,
device=device,
update_statistics=update_statistics,
)
self.cycle_initial_lr = cycle_initial_lr
self.cycle_final_lr = cycle_final_lr
self.num_iterations_per_epoch = num_iterations_per_epoch
def get_swa_scheduler(self, optimizer) -> Union[_LRScheduler, dict]:
return {
"scheduler": CycleLinear(
optimizer=optimizer,
cycle_num_iterations=self.num_iterations_per_epoch,
cycle_initial_lr=self.cycle_initial_lr,
cycle_final_lr=self.cycle_final_lr,
),
"interval": "step",
}
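# Construction sketch (hedged; the exact Trainer integration depends on the
# pytorch-lightning version): the callback is passed to the Trainer like any other
# callback and swaps the lr scheduler once the SWA phase starts. Values are illustrative.
def _demo_build_swa_callback(num_iterations_per_epoch: int,
                             swa_epoch_start: int = 50) -> SWACycleLinear:
    return SWACycleLinear(
        swa_epoch_start=swa_epoch_start,
        cycle_initial_lr=1e-3,
        cycle_final_lr=1e-5,
        num_iterations_per_epoch=num_iterations_per_epoch,
    )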
from nndet.utils.tensor import (
make_onehot_batch,
to_dtype,
to_device,
to_numpy,
to_tensor,
cat,
)
from nndet.utils.info import (
maybe_verbose_iterable,
find_name,
log_git,
get_cls_name,
log_error,
file_logger,
)
from nndet.utils.timer import Timer
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
This is prototype code ... Use at your own risk
This was initially part of a notebook but I needed to move it into
these scriptish functions to run it in my default pipeline
"""
import pickle
from itertools import product
from pathlib import Path
from typing import Sequence, Optional, Tuple
from collections import defaultdict
from loguru import logger
import numpy as np
import pandas as pd
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
import matplotlib.pyplot as plt
plt.style.use('seaborn-deep')
from sklearn.metrics import confusion_matrix
from torch import Tensor
import SimpleITK as sitk
from nndet.core.boxes import box_iou_np, box_size_np
from nndet.io.load import load_pickle, save_json
from nndet.utils.info import maybe_verbose_iterable
def collect_overview(prediction_dir: Path, gt_dir: Path,
iou: float, score: float,
max_num_fp_per_image: int = 5,
top_n: int = 10,
):
results = defaultdict(dict)
for f in prediction_dir.glob("*_boxes.pkl"):
case_id = f.stem.rsplit('_', 1)[0]
gt_data = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
gt_boxes = gt_data["boxes"]
gt_classes = gt_data["classes"]
gt_ignore = [np.zeros(gt_boxes_img.shape[0]).reshape(-1, 1) for gt_boxes_img in [gt_boxes]]
case_result = load_pickle(f)
pred_boxes = case_result["pred_boxes"]
pred_scores = case_result["pred_scores"]
pred_labels = case_result["pred_labels"]
keep = pred_scores > score
pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]
# if "properties" in case_data:
# results[case_id]["orig_spacing"] = case_data["properties"]["original_spacing"]
# results[case_id]["crop_shape"] = [c[1] for c in case_data["properties"]["crop_bbox"]]
# else:
# results[case_id]["orig_spacing"] = None
# results[case_id]["crop_shape"] = None
results[case_id]["num_gt"] = len(gt_classes)
# computation starts here
if gt_boxes.size == 0:
idx = np.argsort(pred_scores)[::-1][:5]
results[case_id]["fp_score"] = pred_scores[idx]
results[case_id]["fp_label"] = pred_labels[idx]
results[case_id]["fp_true_label"] = (np.ones(len(pred_labels)) * -1)
results[case_id]["fp_type"] = ["fp_iou"] * len(pred_labels)
results[case_id]["num_fn"] = 0
elif pred_boxes.size == 0:
results[case_id]["num_fn"] = len(gt_classes)
results[case_id]["fn_boxes"] = gt_boxes
else:
match_quality_matrix = box_iou_np(gt_boxes, pred_boxes)
matched_idxs = np.argmax(match_quality_matrix, axis=0)
matched_vals = np.max(match_quality_matrix, axis=0)
matched_idxs[matched_vals < iou] = -1
matched_gt_boxes_per_image = gt_boxes[matched_idxs.clip(min=0)]
target_labels = gt_classes[matched_idxs.clip(min=0)]
target_labels[matched_idxs == -1] = -1
# True positive analysis
tp_keep = target_labels == pred_labels
tp_boxes, tp_scores, tp_labels = pred_boxes[tp_keep], pred_scores[tp_keep], pred_labels[tp_keep]
keep_high = tp_scores > 0.5
tp_high_boxes, tp_high_scores, tp_high_labels = tp_boxes[keep_high], tp_scores[keep_high], tp_labels[
keep_high]
keep_low = tp_scores < 0.5
tp_low_boxes, tp_low_scores, tp_low_labels = tp_boxes[keep_low], tp_scores[keep_low], tp_labels[keep_low]
high_idx = np.argsort(tp_high_scores)[::-1][:3]
low_idx = np.argsort(tp_low_scores)[:3]
results[case_id]["iou_tp"] = int(tp_keep.sum())
results[case_id]["tp_high_boxes"] = tp_high_boxes[high_idx]
results[case_id]["tp_high_score"] = tp_high_scores[high_idx]
results[case_id]["tp_high_label"] = tp_high_labels[high_idx]
results[case_id]["tp_iou"] = matched_vals[tp_keep]
if tp_low_boxes.size > 0:
results[case_id]["tp_low_boxes"] = tp_low_boxes[low_idx]
results[case_id]["tp_low_score"] = tp_low_scores[low_idx]
results[case_id]["tp_low_label"] = tp_low_labels[low_idx]
# False Positive Analysis
fp_keep = (pred_labels != target_labels) * (pred_labels != -1)
fp_boxes, fp_scores, fp_labels, fp_target_labels = pred_boxes[fp_keep], pred_scores[fp_keep], pred_labels[
fp_keep], target_labels[fp_keep]
idx = np.argsort(fp_scores)[::-1][:max_num_fp_per_image]
# results[case_id]["fp_box"] = fp_boxes[idx]
results[case_id]["fp_score"] = fp_scores[idx]
results[case_id]["fp_label"] = fp_labels[idx]
results[case_id]["fp_true_label"] = fp_target_labels[idx]
results[case_id]["fp_type"] = ["fp_iou" if tl == -1 else "fp_cls" for tl in fp_target_labels]
# Misc
unmatched_gt = (match_quality_matrix.max(axis=1) < iou)
false_negatives = unmatched_gt.sum()
results[case_id]["fn_boxes"] = gt_boxes[unmatched_gt]
results[case_id]["num_fn"] = false_negatives
df = pd.DataFrame.from_dict(results, orient='index')
df = df.sort_index()
analysis_ids = {}
if "fp_score" in list(df.columns):
tmp = df["fp_score"].apply(lambda x: np.max(x) if np.any(x) else 0).nlargest(top_n)
analysis_ids["top_scoring_fp"] = tmp.index.values.tolist()
tmp = df["fp_score"].apply(
lambda x: len(x) if isinstance(x, Sequence) or isinstance(x, np.ndarray) else 0).nlargest(top_n)
analysis_ids["top_num_fp"] = tmp.index.values.tolist()
if "num_fn" in list(df.columns):
tmp = df["num_fn"].nlargest(top_n)
analysis_ids["top_num_fn"] = tmp.index.values.tolist()
return df, analysis_ids
def collect_score_iou(prediction_dir: Path, gt_dir: Path, iou: float, score: float):
all_pred = []
all_target = []
all_pred_ious = []
all_pred_scores = []
for f in prediction_dir.glob("*_boxes.pkl"):
case_id = f.stem.rsplit('_', 1)[0]
gt_data = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
gt_boxes = gt_data["boxes"]
gt_classes = gt_data["classes"]
gt_ignore = [np.zeros(gt_boxes_img.shape[0]).reshape(-1, 1) for gt_boxes_img in [gt_boxes]]
case_result = load_pickle(f)
pred_boxes = case_result["pred_boxes"]
pred_scores = case_result["pred_scores"]
pred_labels = case_result["pred_labels"]
keep = pred_scores > score
pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]
# computation starts here
if gt_boxes.size == 0:
all_pred.append(pred_labels)
all_target.append(np.ones(len(pred_labels)) * -1)
all_pred_ious.append(np.zeros(len(pred_labels)))
all_pred_scores.append(pred_scores)
elif pred_boxes.size == 0:
all_pred.append(np.ones(len(gt_classes)) * -1)
all_target.append(gt_classes)
else:
match_quality_matrix = box_iou_np(gt_boxes, pred_boxes)
matched_idxs = np.argmax(match_quality_matrix, axis=0)
matched_vals = np.max(match_quality_matrix, axis=0)
matched_idxs[matched_vals < iou] = -1
matched_gt_boxes_per_image = gt_boxes[matched_idxs.clip(min=0)]
target_labels = gt_classes[matched_idxs.clip(min=0)]
target_labels[matched_idxs == -1] = -1
all_pred.append(pred_labels)
all_target.append(target_labels)
all_pred_ious.append(matched_vals)
all_pred_scores.append(pred_scores)
false_negatives = (match_quality_matrix.max(axis=1) < iou).sum()
if false_negatives > 0: # false negatives
all_pred.append(np.ones(false_negatives) * -1)
all_target.append(np.zeros(false_negatives))
return all_pred, all_target, all_pred_ious, all_pred_scores
def plot_confusion_matrix(all_pred, all_target, iou: float, score:float):
if len(all_pred) > 0 and len(all_target) > 0:
cm = confusion_matrix(np.concatenate(all_target), np.concatenate(all_pred))
plt.figure()
ax = sns.heatmap(cm, annot=True, cbar=False)
ax.set_xlabel("Prediction")
ax.set_ylabel("Ground Truth")
ax.set_title(f"Confusion Matrix IoU {iou} and Score Threshold {score}")
return ax
else:
return None
def plot_joint_iou_score(all_pred_ious, all_pred_scores):
if isinstance(all_pred_ious, Sequence):
if len(all_pred_ious) == 0:
return None
all_pred_ious = np.concatenate(all_pred_ious)
if isinstance(all_pred_scores, Sequence):
if len(all_pred_scores) == 0:
return None
all_pred_scores = np.concatenate(all_pred_scores)
plt.figure()
f = sns.jointplot(x=all_pred_ious, y=all_pred_scores,
xlim=(-0.01, 1.01), ylim=(-0.01, 1.01), marginal_kws={"bins": 10},
kind='reg', scatter=True,
)
plt.plot([0, 1], [0, 1], 'g')
f.set_axis_labels("IoU", "Predicted Score")
f.ax_joint.axvline(x=0.1, c='r')
f.ax_joint.axvline(x=0.5, c='r')
f.fig.subplots_adjust(top=0.9)
f.fig.suptitle('Class independent predicted score over IoU plot', fontsize=16)
return f
def collect_boxes(prediction_dir: Path, gt_dir: Path, iou:float, score: float):
all_pred = []
all_target = []
all_boxes = []
i = 0
for f in prediction_dir.glob("*_boxes.pkl"):
case_id = f.stem.rsplit('_', 1)[0]
gt_data = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
gt_boxes = gt_data["boxes"]
gt_classes = gt_data["classes"]
gt_ignore = [np.zeros(gt_boxes_img.shape[0]).reshape(-1, 1) for gt_boxes_img in [gt_boxes]]
case_result = load_pickle(f)
pred_boxes = case_result["pred_boxes"]
pred_scores = case_result["pred_scores"]
pred_labels = case_result["pred_labels"]
keep = pred_scores > score
pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]
# computation starts here
if gt_boxes.size == 0:
all_pred.append(pred_labels)
all_target.append(np.ones(len(pred_labels)) * -1)
all_boxes.append(pred_boxes)
elif pred_boxes.size == 0:
all_pred.append(np.ones(len(gt_classes)) * -1)
all_target.append(gt_classes)
all_boxes.append(gt_boxes)
else:
match_quality_matrix = box_iou_np(gt_boxes, pred_boxes)
matched_idxs = np.argmax(match_quality_matrix, axis=0)
matched_vals = np.max(match_quality_matrix, axis=0)
matched_idxs[matched_vals < iou] = -1
matched_gt_boxes_per_image = gt_boxes[matched_idxs.clip(min=0)]
target_labels = gt_classes[matched_idxs.clip(min=0)]
target_labels[matched_idxs == -1] = -1
all_pred.append(pred_labels)
all_target.append(target_labels)
all_boxes.append(pred_boxes)
unmatched_gt = (match_quality_matrix.max(axis=1) < iou)
false_negatives = unmatched_gt.sum()
if false_negatives > 0: # false negatives
all_pred.append(np.ones(false_negatives) * -1)
all_target.append(np.zeros(false_negatives))
all_boxes.append(gt_boxes[np.nonzero(unmatched_gt)[0]])
return all_pred, all_target, all_boxes
def plot_sizes(all_pred, all_target, all_boxes, iou, score):
if len(all_pred) == 0 or len(all_target) == 0:
return None, None
_all_pred = np.concatenate(all_pred)
_all_target = np.concatenate(all_target)
_all_boxes = np.concatenate([ab for ab in all_boxes if ab.size > 0])
dists = box_size_np(_all_boxes)
tp_mask = _all_pred == _all_target
fp_mask = (_all_pred != _all_target) * (_all_pred != -1)
fn_mask = (_all_pred != _all_target) * (_all_pred == -1)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(dists[tp_mask, 0], dists[tp_mask, 1], dists[tp_mask, 2], c='g', marker='o', label="tp")
ax.scatter(dists[fp_mask, 0], dists[fp_mask, 1], dists[fp_mask, 2], c='r', marker='x', label="fp")
ax.scatter(dists[fn_mask, 0], dists[fn_mask, 1], dists[fn_mask, 2], c='b', marker='^', label="fn")
ax.set_title(
f"IoU {iou} and Score Threshold {score}: tp {sum(tp_mask)} fp {sum(fp_mask)} fn {sum(fn_mask)}")
ax.set_xlabel('bounding box size axis 0')
ax.set_ylabel('bounding box size axis 1')
ax.set_zlabel('bounding box size axis 2')
ax.legend()
return fig, ax
def plot_sizes_bar(all_pred, all_target, all_boxes, iou, score,
max_bin: Optional[int] = None ):
if len(all_pred) == 0 or len(all_target) == 0:
return None, None
_all_pred = np.concatenate(all_pred)
_all_target = np.concatenate(all_target)
_all_boxes = np.concatenate([ab for ab in all_boxes if ab.size > 0])
dists = box_size_np(_all_boxes)
tp_mask = _all_pred == _all_target
fp_mask = (_all_pred != _all_target) * (_all_pred != -1)
fn_mask = (_all_pred != _all_target) * (_all_pred == -1)
fig = plt.figure()
# ax = fig.add_subplot(111)
data = {
"tp": dists[tp_mask, 0] + dists[tp_mask, 1] + dists[tp_mask, 2],
"fp": dists[fp_mask, 0] + dists[fp_mask, 1] + dists[fp_mask, 2],
"fn": dists[fn_mask, 0] + dists[fn_mask, 1] + dists[fn_mask, 2],
}
# plt.hist(x=[data["tp"], data["fp"], data["fn"]],
# bins=100, label=["tp", "fp", "fn"], stacked=False,
# histtype="step",
# )
# ax = plt.gca()
kwargs = {}
if max_bin is not None:
kwargs["binrange"] = [0, max_bin]
ax = sns.histplot(data=data, bins=100, element="step",
palette={"tp": "g", "fp": "r", "fn": "b"},
legend=True, fill=False, **kwargs
)
ax.set_title(
f"IoU {iou} and Score Threshold {score}: tp {sum(tp_mask)} fp {sum(fp_mask)} fn {sum(fn_mask)}")
ax.set_xlabel("box width + height ( + depth)")
ax.set_ylabel("Count")
return fig, ax
def run_analysis_suite(prediction_dir: Path, gt_dir: Path, save_dir: Path):
for iou, score in maybe_verbose_iterable(list(product([0.1, 0.5], [0.1, 0.5]))):
_save_dir = save_dir / f"iou_{iou}_score_{score}"
_save_dir.mkdir(parents=True, exist_ok=True)
found_predictions = list(prediction_dir.glob("*_boxes.pkl"))
logger.info(f"Found {len(found_predictions)} predictions for analysis")
df, analysis_ids = collect_overview(prediction_dir, gt_dir,
iou=iou, score=score,
max_num_fp_per_image=5,
top_n=10,
)
df.to_json(_save_dir / "analysis.json", indent=4, orient='index')
df.to_csv(_save_dir / "analysis.csv")
save_json(analysis_ids, _save_dir / "analysis_ids.json")
all_pred, all_target, all_pred_ious, all_pred_scores = collect_score_iou(
prediction_dir, gt_dir, iou=iou, score=score)
confusion_ax = plot_confusion_matrix(all_pred, all_target, iou=iou, score=score)
plt.savefig(_save_dir / "confusion_matrix.png")
plt.close()
iou_score_ax = plot_joint_iou_score(all_pred_ious, all_pred_scores)
plt.savefig(_save_dir / "joint_iou_score.png")
plt.close()
all_pred, all_target, all_boxes = collect_boxes(
prediction_dir, gt_dir, iou=iou, score=score)
sizes_fig, sizes_ax = plot_sizes(all_pred, all_target, all_boxes, iou=iou, score=score)
plt.savefig(_save_dir / "sizes.png")
with open(str(_save_dir / 'sizes.pkl'), "wb") as fp:
pickle.dump(sizes_fig, fp, protocol=4)
plt.close()
sizes_fig, sizes_ax = plot_sizes_bar(all_pred, all_target, all_boxes, iou=iou, score=score)
plt.savefig(_save_dir / "sizes_bar.png")
with open(str(_save_dir / 'sizes_bar.pkl'), "wb") as fp:
pickle.dump(sizes_fig, fp, protocol=4)
plt.close()
sizes_fig, sizes_ax = plot_sizes_bar(all_pred, all_target, all_boxes,
iou=iou, score=score, max_bin=100)
plt.savefig(_save_dir / "sizes_bar_100.png")
with open(str(_save_dir / 'sizes_bar_100.pkl'), "wb") as fp:
pickle.dump(sizes_fig, fp, protocol=4)
plt.close()
def convert_box_to_nii_meta(pred_boxes: Tensor,
pred_scores: Tensor,
pred_labels: Tensor,
props: dict,
) -> Tuple[sitk.Image, dict]:
instance_mask = np.zeros(props["original_size_of_raw_data"], dtype=np.uint8)
for instance_id, pbox in enumerate(pred_boxes, start=1):
mask_slicing = [slice(int(pbox[0]), int(pbox[2])),
slice(int(pbox[1]), int(pbox[3])),
]
if instance_mask.ndim == 3:
mask_slicing.append(slice(int(pbox[4]), int(pbox[5])))
instance_mask[tuple(mask_slicing)] = instance_id
logger.info(f"Created instance mask with {instance_mask.max()} instances.")
instance_mask_itk = sitk.GetImageFromArray(instance_mask)
instance_mask_itk.SetOrigin(props["itk_origin"])
instance_mask_itk.SetDirection(props["itk_direction"])
instance_mask_itk.SetSpacing(props["itk_spacing"])
prediction_meta = {idx: {"score": float(score), "label": int(label)}
for idx, (score, label) in enumerate(zip(pred_scores, pred_labels), start=1)}
return instance_mask_itk, prediction_meta
import functools
import os
import warnings
from pathlib import Path
from typing import List, Sequence, Optional
import numpy as np
import SimpleITK as sitk
from nndet.io import load_json, load_sitk
from nndet.io.paths import get_task, get_paths_from_splitted_dir
from nndet.utils.config import load_dataset_info
from nndet.utils.info import maybe_verbose_iterable
def env_guard(func):
"""
Decorator to check nnDetection environment variables
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# we use print here because logging might not be initialized yet and
# this is intended as a user warning.
# det_data
if os.environ.get("det_data", None) is None:
raise RuntimeError(
"'det_data' environment variable not set. "
"Please refer to the installation instructions. "
)
# det_models
if os.environ.get("det_models", None) is None:
raise RuntimeError(
"'det_models' environment variable not set. "
"Please refer to the installation instructions. "
)
# OMP_NUM_THREADS
if os.environ.get("OMP_NUM_THREADS", None) is None:
raise RuntimeError(
"'OMP_NUM_THREADS' environment variable not set. "
"Please refer to the installation instructions. "
)
# det_num_threads
if os.environ.get("det_num_threads", None) is None:
warnings.warn(
"Warning: 'det_num_threads' environment variable not set. "
"Please read installation instructions again. "
"Training will not work properly.")
# det_verbose
if os.environ.get("det_verbose", None) is None:
print("'det_verbose' environment variable not set. "
"Continue in verbose mode.")
return func(*args, **kwargs)
return wrapper
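# Usage sketch (illustrative, not part of the original module): guard an entry point
# so it fails early with a clear message when required environment variables are missing.
@env_guard
def _demo_guarded_main() -> None:
    print("nnDetection environment variables are set up correctly")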
def _check_key_missing(cfg: dict, key: str, ktype=None):
if key not in cfg:
raise ValueError(f"Dataset information did not contain "
f"'{key}' key, found {list(cfg.keys())}")
if ktype is not None:
if not isinstance(cfg[key], ktype):
raise ValueError(f"Found {key} of type {type(cfg[key])} in "
f"dataset information but expected type {ktype}")
def check_dataset_file(task_name: str):
"""
Run a sequence of checks to confirm correct format of dataset information
Args:
task_name: task identifier to check info for
"""
print("Start dataset info check.")
cfg = load_dataset_info(get_task(task_name))
_check_key_missing(cfg, "task", ktype=str)
_check_key_missing(cfg, "dim", ktype=int)
_check_key_missing(cfg, "labels", ktype=dict)
_check_key_missing(cfg, "modalities", ktype=dict)
# check dim
if (dim := cfg["dim"]) not in [2, 3]:
raise ValueError(f"Found dim {dim} in dataset info but only support dim=2 or dim=3.")
# check labels
for key, item in cfg["labels"].items():
if not isinstance(key, (str, int)):
raise ValueError("Expected key of type string in dataset "
f"info labels but found {type(key)} : {key}")
if not isinstance(item, (str, int)):
raise ValueError("Expected name of type string in dataset "
f"info labels but found {type(item)} : {item}")
found_classes = sorted(list(map(int, cfg["labels"].keys())))
for ic, idx in enumerate(found_classes):
if ic != idx:
raise ValueError("Found wrong order of label classes in dataset info."
f"Found {found_classes} but expected {list(range(len(found_classes)))}")
# check modalities
for key, item in cfg["modalities"].items():
if not isinstance(key, (str, int)):
raise ValueError("Expected key of type string in dataset "
f"info modalities but found {type(key)} : {key}")
if not isinstance(item, (str, int)):
raise ValueError("Expected name of type string in dataset "
f"info modalities but found {type(item)} : {item}")
found_mods = sorted(list(map(int, cfg["modalities"].keys())))
for ic, idx in enumerate(found_mods):
if ic != idx:
raise ValueError("Found wrong order of modalities in dataset info."
f"Found {found_mods} but expected {list(range(len(found_mods)))}")
# check target class
target_class = cfg.get("target_class", None)
if target_class is not None and not isinstance(target_class, int):
raise ValueError("If target class is defined, it needs to be an integer, "
f"found {type(target_class)} : {target_class}")
print("Dataset info check complete.")
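# Minimal example (illustrative assumption) of the dataset information the checks
# above expect; only the keys and value types are derived from check_dataset_file,
# the concrete names and values are made up.
_EXAMPLE_DATASET_INFO = {
    "task": "Task000_Example",
    "dim": 3,
    "labels": {"0": "lesion"},
    "modalities": {"0": "CT"},
    "target_class": 0,  # optional; must be an int when present
}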
def check_data_and_label_splitted(
task_name: str,
test: bool = False,
labels: bool = True,
full_check: bool = True,
):
"""
Perform checks of data and label in raw splitted format
Args:
task_name: name of task to check
test: check test data
labels: check labels
full_check: By default a full check will be performed which needs to
load all files. If this is disabled, a computationally light check
will be performed
Raises:
ValueError: if not all raw splitted files were found
ValueError: missing label info file
ValueError: instances in label info file need to start at 1
ValueError: instances in label info file need to be consecutive
"""
print("Start data and label check.")
cfg = load_dataset_info(get_task(task_name))
splitted_paths = get_paths_from_splitted_dir(
num_modalities=len(cfg["modalities"]),
splitted_4d_output_dir=Path(os.getenv('det_data')) / task_name / "raw_splitted",
labels=labels,
test=test,
)
for case_paths in maybe_verbose_iterable(splitted_paths):
# check all files exist
for cp in case_paths:
if not Path(cp).is_file():
raise ValueError(f"Expected {cp} to be a raw splitted "
"data path but it does not exist.")
if labels:
# check label info (json files)
mask_path = case_paths[-1]
mask_info_path = mask_path.parent / f"{mask_path.stem.split('.')[0]}.json"
if not Path(mask_info_path).is_file():
raise ValueError(f"Expected {mask_info_path} to be a raw splitted "
"mask info path but it does not exist.")
mask_info = load_json(mask_info_path)
if mask_info["instances"]:
mask_info_instances = list(map(int, mask_info["instances"].keys()))
if (j := min(mask_info_instances)) != 1:
raise ValueError(f"Instance IDs need to start at 1, found {j} in {mask_info_path}")
for i in range(1, len(mask_info_instances) + 1):
if i not in mask_info_instances:
raise ValueError(f"Expected {i} to be an Instance ID in "
f"{mask_info_path} but only found {mask_info_instances}")
else:
mask_info_path = None
if full_check:
_full_check(case_paths, mask_info_path)
print("Data and label check complete.")
def _full_check(
case_paths: List[Path],
mask_info_path: Optional[Path] = None,
) -> None:
"""
Performs itk and instance checks on the provided paths
Args:
case_paths: paths to all itk images to check properties
if label is provided it needs to be at the last position
mask_info_path: optionally check label properties. If None, no
check of label properties will be performed.
Raises:
ValueError: Inconsistent instances in label info and label image
See also:
:func:`_check_itk_params`
"""
img_itk_seq = [load_sitk(cp) for cp in case_paths]
_check_itk_params(img_itk_seq, case_paths)
if mask_info_path is not None:
mask_itk = img_itk_seq[-1]
mask_info = load_json(mask_info_path)
info_instances = list(map(int, mask_info["instances"].keys()))
mask_instances = np.unique(sitk.GetArrayViewFromImage(mask_itk))
mask_instances = mask_instances[mask_instances > 0]
for mi in mask_instances:
if mi not in info_instances:
raise ValueError(f"Found instance ID {mi} in mask which is "
f"not present in info {info_instances} in {mask_info_path}")
if not len(info_instances) == len(mask_instances):
raise ValueError("Found instances in info which are not present in mask: "
f"mask: {mask_instances} info {info_instances} in {mask_info_path}")
def _check_itk_params(
img_seq: Sequence[sitk.Image],
paths: Sequence[Path],
) -> None:
"""
Check Dimension, Origin, Direction and Spacing of a Sequence of images
Args:
img_seq: sequence of images to check
paths: corresponding paths of images (for error msg)
Raises:
ValueError: raised if dimensions do not match
ValueError: raised if origin does not match
ValueError: raised if direction does not match
ValueError: raised if spacing does not match
"""
for idx, img in enumerate(img_seq[1:], start=1):
if not (np.asarray(img_seq[0].GetDimension()) == \
np.asarray(img.GetDimension())).all():
raise ValueError(f"Expected {paths[idx]} and {paths[0]} to have same dimensions!")
if not (np.asarray(img_seq[0].GetOrigin()) == \
np.asarray(img.GetOrigin())).all():
raise ValueError(f"Expected {paths[idx]} and {paths[0]} to have same origin!")
if not (np.asarray(img_seq[0].GetDirection()) == \
np.asarray(img.GetDirection())).all():
raise ValueError(f"Expected {paths[idx]} and {paths[0]} to have same direction!")
if not (np.asarray(img_seq[0].GetSpacing()) == \
np.asarray(img.GetSpacing())).all():
raise ValueError(f"Expected {paths[idx]} and {paths[0]} to have same spacing!")
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from scipy.ndimage import label
from typing import Dict, Sequence, Union, Tuple, Optional
from nndet.io.transforms.instances import get_bbox_np
def seg2instances(seg: np.ndarray,
exclude_background: bool = True,
min_num_voxel: int = 0,
) -> Tuple[np.ndarray, Dict[int, int]]:
"""
Use connected components with a ones matrix to create instances from a segmentation
Args:
seg: semantic segmentation [spatial dims]
exclude_background: skips background class for the mapping
from instances to classes
min_num_voxel: minimum number of voxels of an instance
Returns:
np.ndarray: instance segmentation
Dict[int, int]: mapping from instances to classes
"""
structure = np.ones([3] * seg.ndim)
instances_temp, _ = label(seg, structure=structure)
instance_ids = np.unique(instances_temp)
if exclude_background:
instance_ids = instance_ids[instance_ids > 0]
instance_classes = {}
instances = np.zeros_like(instances_temp)
i = 1
for iid in instance_ids:
instance_binary_mask = instances_temp == iid
if min_num_voxel > 0:
if instance_binary_mask.sum() < min_num_voxel: # remove small instances
continue
instances[instance_binary_mask] = i # save instance to final mask
single_idx = np.argwhere(instance_binary_mask)[0] # select semantic class
semantic_class = int(seg[tuple(single_idx)])
instance_classes[int(i)] = semantic_class # save class
i = i + 1 # bump instance index
return instances, instance_classes
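# Toy example (illustrative, not part of the original module): a semantic
# segmentation with two disconnected foreground blobs of class 1 yields two
# instances, both mapped to class 1.
def _demo_seg2instances() -> Tuple[np.ndarray, Dict[int, int]]:
    seg = np.zeros((8, 8), dtype=np.int32)
    seg[1:3, 1:3] = 1
    seg[5:7, 5:7] = 1
    # two instance ids (1 and 2) are assigned; instance_classes == {1: 1, 2: 1}
    return seg2instances(seg)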
def remove_classes(seg: np.ndarray, rm_classes: Sequence[int], classes: Dict[int, int] = None,
background: int = 0) -> Union[np.ndarray, Tuple[np.ndarray, Dict[int, int]]]:
"""
Remove classes from segmentation (also works on instances
but instance ids may not be consecutive anymore)
Args:
seg: segmentation [spatial dims]
rm_classes: classes which should be removed
classes: optional mapping from instances from segmentation to classes
background: background value
Returns:
np.ndarray: segmentation where classes are removed
Dict[int, int]: updated mapping from instances to classes
"""
for rc in rm_classes:
seg[seg == rc] = background
if classes is not None:
classes.pop(rc)
if classes is None:
return seg
else:
return seg, classes
def reorder_classes(seg: np.ndarray, class_mapping: Dict[int, int]) -> np.ndarray:
"""
Reorders classes in segmentation
Args:
seg: segmentation
class_mapping: mapping from source id to new id
Returns:
np.ndarray: remapped segmentation
"""
for source_id, target_id in class_mapping.items():
seg[seg == source_id] = target_id
return seg
def compute_score_from_seg(instances: np.ndarray,
instance_classes: Dict[int, int],
probs: np.ndarray,
aggregation: str = "max",
) -> np.ndarray:
"""
Combine scores for each instance given an instance mask and instance logits
Args:
instances: instance segmentation [dims]; dims can be arbitrary dimensions
instance_classes: assign each instance id to a class (id -> class)
probs: predicted probabilities for each class [C, dims];
C = number of classes, dims need to have the same dimensions as
instances
aggregation: defines the aggregation method for the probabilities.
One of 'max', 'mean', 'median', 'percentile95'
Returns:
Sequence[float]: Probability for each instance
"""
instance_classes = {int(key): int(item) for key, item in instance_classes.items()}
instance_ids = list(instance_classes.keys())
instance_scores = []
for iid in instance_ids:
ic = instance_classes[iid]
instance_mask = instances == iid
instance_probs = probs[ic][instance_mask]
if aggregation == "max":
_score = np.max(instance_probs)
elif aggregation == "mean":
_score = np.mean(instance_probs)
elif aggregation == "median":
_score = np.median(instance_probs)
elif aggregation == "percentile95":
_score = np.percentile(instance_probs, 95)
else:
raise ValueError(f"Aggregation {aggregation} is not a supported aggregation")
instance_scores.append(_score)
return np.asarray(instance_scores)
def instance_results_from_seg(probs: np.ndarray,
aggregation: str,
stuff: Optional[Sequence[int]] = None,
min_num_voxel: int = 0,
min_threshold: Optional[float] = None,
) -> dict:
"""
Compute instance segmentation results from a semantic segmentation
argmax -> remove stuff classes -> connected components ->
aggregate score inside each instance
Args:
probs: Predicted probabilities for each class [C, dims];
C = number of classes, dims can be arbitrary dimensions
aggregation: defines the aggregation method for the probabilities.
One of 'max', 'mean', 'median', 'percentile95'
stuff: stuff classes to be ignored during conversion.
min_num_voxel: minimum number of voxels of an instance
min_threshold: if None argmax is used. If a threshold is provided
it is used as a probability threshold for the foreground class.
if multiple foreground classes exceed the threshold, the
foreground class with the largest probability is selected.
Returns:
dict: predictions
`pred_instances`: instance segmentation [dims]
`pred_boxes`: predicted bounding boxes [2 * spatial dims]
`pred_labels`: predicted class for each instance/box
`pred_scores`: predicted score for each instance/box
"""
if min_threshold is not None:
if probs.shape[0] > 2:
fg_argmax = np.argmax(probs, axis=0)
fg_mask = np.max(probs[1:], axis=0) > min_threshold
seg = np.zeros_like(probs[0])
seg[fg_mask] = fg_argmax[fg_mask]
else:
seg = probs[1] > min_threshold
else:
seg = np.argmax(probs, axis=0)
if stuff is not None:
for s in stuff:
seg[seg == s] = 0
instances, instance_classes = seg2instances(seg,
exclude_background=True,
min_num_voxel=min_num_voxel,
)
instance_scores = compute_score_from_seg(instances, instance_classes, probs,
aggregation=aggregation)
instance_classes = {int(key): int(item) - 1 for key, item in instance_classes.items()}
tmp = get_bbox_np(instances[None], instance_classes)
instance_boxes = tmp["boxes"]
instance_classes_seq = tmp["classes"]
return {
"pred_instances": instances,
"pred_boxes": instance_boxes,
"pred_labels": instance_classes_seq,
"pred_scores": instance_scores,
}
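# Toy example (illustrative, not part of the original module): probabilities for
# background + one foreground class are converted into instances, boxes, labels and scores.
def _demo_instance_results_from_seg() -> dict:
    fg = np.zeros((8, 8, 8), dtype=np.float32)
    fg[2:5, 2:5, 2:5] = 0.9
    probs = np.stack([1.0 - fg, fg], axis=0)  # [C=2, spatial dims]
    return instance_results_from_seg(probs, aggregation="max")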
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import importlib
from pathlib import Path
import yaml
from omegaconf import OmegaConf
from hydra.experimental import compose as hydra_compose
from nndet.io.paths import Pathlike, get_task
def load_dataset_info(task_dir: Pathlike) -> dict:
"""
Load dataset information from a given task directory
Args:
task_dir: path to directory of specific task e.g. ../Task12_LIDC
Returns:
dict: loaded dataset info. Typically includes:
`name` (str): name of dataset
`target_class` (str)
"""
task_dir = Path(task_dir)
yaml_path = task_dir / "dataset.yaml"
json_path = task_dir / "dataset.json"
if yaml_path.is_file():
with open(yaml_path, 'r') as f:
data = yaml.full_load(f)
elif json_path.is_file():
with open(json_path, "r") as f:
data = json.load(f)
else:
raise RuntimeError(f"Did not find dataset.json or dataset.yaml in {task_dir}")
return data
def compose(task, *args, models: bool = False, **kwargs) -> dict:
cfg = hydra_compose(*args, **kwargs)
OmegaConf.set_struct(cfg, False)
task_name = get_task(task, name=True, models=models)
cfg["task"] = task_name
cfg["data"] = load_dataset_info(get_task(task_name))
for imp in cfg.get("additional_imports", []):
print(f"Additional import found {imp}")
importlib.import_module(imp)
return cfg
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import copy
import pathlib
import warnings
import functools
from collections.abc import MutableMapping
from subprocess import PIPE, run
from omegaconf.omegaconf import OmegaConf
from tqdm import tqdm
from typing import Mapping, Sequence, Union, Callable, Any, Iterable
from loguru import logger
from contextlib import contextmanager
from typing import Union, Optional
from pathlib import Path
from git import Repo, InvalidGitRepositoryError
def get_requirements():
"""
Get all installed packages from currently active environment
Returns:
str: list with all requirements
"""
command = ['pip', 'list']
result = run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True)
assert not result.stderr, "stderr not empty"
return result.stdout
def write_requirements_to_file(path: Union[str, Path]) -> None:
"""
Write all installed packages from currently active environment to file
Args:
path (str): path to file (including file name and extension)
"""
with open(path, "w+") as f:
f.write(get_requirements())
def get_repo_info(path: Union[str, Path]):
"""
Parse repository information from path
Args:
path (str): path to repo. If path is not a repository it
searches parent folders for a repository
Returns:
dict: contains the current hash, gitdir and active branch
"""
def find_repo(findpath):
p = Path(findpath).absolute()
for p in [p, *p.parents]:
try:
repo = Repo(p)
break
except InvalidGitRepositoryError:
pass
else:
raise InvalidGitRepositoryError
return repo
repo = find_repo(path)
return {"hash": repo.head.commit.hexsha,
"gitdir": repo.git_dir,
"active_branch": repo.active_branch.name}
def maybe_verbose_iterable(data: Iterable, **kwargs) -> Iterable:
"""
If verbose flag of nndet is enabled, uses tqdm to create a
progress bar
Args:
data: iterable to wrap
**kwargs: keyword arguments passed to tqdm
Returns:
Iterable: iterable, possibly with a progress bar attached to it
"""
if bool(int(os.getenv("det_verbose", 1))):
return tqdm(data, **kwargs)
else:
return data
def find_name(tdir: Union[str, Path], name: str,
postfix: Optional[str] = None) -> Path:
"""
Generates non-existing names for files and dirs by adding a counter to
the end
Args:
tdir: target directory where name should be determined for
name: base name for string
postfix: postfix for name+counter. Defaults to None.
Raises:
RuntimeError: this function only works up to the counter of 1000
Returns:
Path: path to generated item
"""
if not isinstance(tdir, Path):
tdir = Path(tdir)
if not tdir.is_dir():
tdir.mkdir(parents=True)
if postfix is None:
postfix = ""
i=0
while True:
output_dir = tdir / f"{name}{i:03d}{postfix}"
if not output_dir.exists():
break
if i > 1000:
raise RuntimeError(f"Was not able to find name for tdir {tdir} and {name}")
i += 1
return output_dir
def log_git(repo_path: Union[pathlib.Path, str], repo_name: str = None):
"""
Collect git repository information, logging an error if it cannot be read
Args:
repo_path (Union[pathlib.Path, str]): path to repo or file inside repository (repository is recursively searched)
"""
try:
git_info = get_repo_info(repo_path)
return git_info
except Exception:
logger.error("Was not able to read git information, trying to continue without.")
return {}
def get_cls_name(obj: Any, package_name: bool = True) -> str:
"""
Get name of class from object
Args:
obj (Any): any object
package_name (bool): append package origin at the beginning
Returns:
str: name of class
"""
cls_name = str(obj.__class__)
# remove class prefix
cls_name = cls_name.split('\'')[1]
# split modules
cls_split = cls_name.split('.')
if len(cls_split) > 1:
cls_name = cls_split[0] + '.' + cls_split[-1] if package_name else cls_split[-1]
else:
cls_name = cls_split[0]
return cls_name
def log_error(fn: Callable) -> Any:
"""
Log error messages in hydra log when they occur
Args:
fn: function to wrap
Returns:
Any
"""
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception as e:
logger.error(str(e))
raise e
return wrapper
@contextmanager
def file_logger(path: Union[str, Path], level: str = "DEBUG", overwrite: bool = True):
"""
context manager to automatically clean up file logger
Args:
path: path to output file
level: logging level. Defaults to "DEBUG".
overwrite: remove an existing log file at path before adding the handler. Defaults to True.
Yields:
None
"""
path = Path(path)
if overwrite and path.is_file():
os.remove(path)
logger_id = logger.add(path, level=level)
try:
yield None
finally:
logger.remove(logger_id)
def create_debug_plan(plan: dict) -> str:
_plan = copy.deepcopy(plan)
_plan.pop("dataset_properties", None)
_plan.pop("original_spacings", None)
_plan.pop("original_sizes", None)
return stringify_nested_dict(_plan)
def stringify_nested_dict(data: dict):
if isinstance(data, dict):
return {str(key): stringify_nested_dict(item) for key, item in data.items()}
elif isinstance(data, (list, tuple)):
return [stringify_nested_dict(item) for item in data]
else:
return str(data)
def flatten_mapping(
nested_mapping: Mapping,
sep: str = ".",
) -> Mapping[str, Any]:
_mapping = {}
for key, item in nested_mapping.items():
if isinstance(item, MutableMapping):
for _key, _item in flatten_mapping(item, sep=sep).items():
_mapping[str(key) + sep + str(_key)] = _item
else:
_mapping[str(key)] = item
return _mapping
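# Example (illustrative, not part of the original module): nested mappings are
# flattened into dot-separated keys, e.g. {"a": {"b": 1}, "c": 2} -> {"a.b": 1, "c": 2}.
def _demo_flatten_mapping() -> Mapping[str, Any]:
    return flatten_mapping({"a": {"b": 1}, "c": 2})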
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import shutil
import json
from itertools import repeat
from multiprocessing import Pool
import SimpleITK as sitk
import numpy as np
from loguru import logger
from pathlib import Path
from typing import Sequence, Union
from nndet.io.itk import load_sitk_as_array, load_sitk
from nndet.io.load import save_json, load_json
from nndet.io.paths import get_case_ids_from_dir
from nndet.io.prepare import sitk_copy_metadata
from nndet.io.transforms.instances import instances_to_segmentation_np
Pathlike = Union[str, Path]
class Exporter:
"""
Helper to export datasets to nnunet
"""
def __init__(self,
data_info: dict,
tr_image_dir: Pathlike,
label_dir: Pathlike,
target_dir: Pathlike,
ts_image_dir: Pathlike = None,
export_stuff: bool = False,
processes: int = 6,
):
"""
Args:
data_info: dataset information. See :method:`export_dataset_info`.
Required keys: `modality`, `labels`
tr_image_dir: training data dir
label_dir: label data dir
target_dir: target directory
ts_image_dir: test data dir
export_stuff: export stuff segmentations
"""
self.data_info = data_info
self.tr_image_dir = Path(tr_image_dir)
self.label_dir = Path(label_dir)
self.target_dir = Path(target_dir)
self.export_stuff = export_stuff
self.processes = processes
if ts_image_dir is not None:
self.ts_image_dir = Path(ts_image_dir)
else:
self.ts_image_dir = None
def export(self):
"""
Export entire dataset
"""
self.export_images()
self.export_labels()
self.export_dataset_info()
def export_images(self):
"""
Export images
"""
# data can be copied directly
for img_dir in [self.tr_image_dir, self.ts_image_dir]:
if img_dir is None:
continue
image_target_dir = self.target_dir / img_dir.stem
logger.info(f"Copy data from {img_dir} to {image_target_dir}")
shutil.copytree(img_dir, image_target_dir)
def export_labels(self):
"""
Export labels
"""
case_ids = get_case_ids_from_dir(self.label_dir, remove_modality=False, pattern="*.json")
label_target_dir = self.target_dir / self.label_dir.stem
label_target_dir.mkdir(exist_ok=True, parents=True)
num_classes = len(self.data_info.get("labels", {}))
if num_classes == 0:
logger.warning(f"Did not find any fg classes.")
logger.info(f"Found {len(case_ids)} to process.")
logger.info(f"Export stuff: {self.export_stuff}")
if self.processes == 0:
logger.info("Using for loop to export labels")
for cid in case_ids:
self._export_label(cid, num_classes, label_target_dir)
else:
logger.info(f"Using pool with {self.processes} processes to export labels")
with Pool(processes=self.processes) as p:
p.starmap(self._export_label, zip(
case_ids, repeat(num_classes), repeat(label_target_dir)))
assert len(get_case_ids_from_dir(
label_target_dir, remove_modality=False, pattern="*.nii.gz")) == len(case_ids)
def _export_label(self, cid: str, num_classes: int, target_dir: Path):
logger.info(f"Processing {cid}")
meta = load_json(self.label_dir / f"{cid}.json")
instance_seg_itk = sitk.ReadImage(str(self.label_dir / f"{cid}.nii.gz"))
instance_seg = sitk.GetArrayFromImage(instance_seg_itk)
if np.any(np.isnan(instance_seg)):
logger.error(f"FOUND NAN IN {cid} LABEL")
# instance classes start from 0, which is the background label in nnUNet
seg = instances_to_segmentation_np(instance_seg,
meta["instances"],
add_background=True,
)
if num_classes > 0:
assert seg.max() <= num_classes, "Wrong class id, something went wrong."
if instance_seg.max() > 0:
assert seg.max() > 0, "Instance got lost, something went wrong"
assert np.all((instance_seg > 0) == (seg > 0)), "Something wrong with foreground"
assert np.all((instance_seg == 0) == (seg == 0)), "Something wrong with background"
if self.export_stuff:
# map stuff classes to: max(labels) + stuff_cls
stuff_seg = load_sitk_as_array(self.label_dir / f"{cid}_stuff.nii.gz")[0]
for i in range(1, stuff_seg.max() + 1):
seg[stuff_seg == i] = num_classes + i
seg_itk = sitk.GetImageFromArray(seg)
spacing = instance_seg_itk.GetSpacing()
seg_itk.SetSpacing(spacing)
origin = instance_seg_itk.GetOrigin()
seg_itk.SetOrigin(origin)
direction = instance_seg_itk.GetDirection()
seg_itk.SetDirection(direction)
sitk.WriteImage(seg_itk, str(target_dir / f"{cid}.nii.gz"))
def export_dataset_info(self):
"""
Export dataset settings (dataset.json for nnunet)
"""
self.target_dir.mkdir(exist_ok=True, parents=True)
dataset_info = {}
dataset_info["name"] = self.data_info.get("name", "unknown")
dataset_info["description"] = self.data_info.get("description", "unknown")
dataset_info["reference"] = self.data_info.get("reference", "unknown")
dataset_info["licence"] = self.data_info.get("licence", "unknown")
dataset_info["release"] = self.data_info.get("release", "unknown")
min_size = self.data_info.get("min_size", 0)
min_vol = self.data_info.get("min_vol", 0)
dataset_info["prep_info"] = f"min size: {min_size} ; min vol {min_vol}"
dataset_info["tensorImageSize"] = f"{self.data_info.get('dim', 3)}D"
# dataset_info["tensorImageSize"] = self.data_info.get("tensorImageSize", "4D")
dataset_info["modality"] = self.data_info.get("modalities", {})
if not dataset_info["modality"]:
logger.error("Did not find any modalities for dataset")
# +1 for seg classes because of background
dataset_info["labels"] = {"0": "background"}
instance_classes = self.data_info.get("labels", {})
if not instance_classes:
logger.error("Did not find any labels of dataset")
for _id, _class in instance_classes.items():
seg_id = int(_id) + 1
dataset_info["labels"][str(seg_id)] = _class
if self.export_stuff:
stuff_classes = self.data_info.get("labels_stuff", {})
num_instance_classes = len(instance_classes)
# copy stuff classes into the nnunet dataset.json
stuff_classes = {
str(int(key) + num_instance_classes): item
for key, item in stuff_classes.items() if int(key) > 0
}
dataset_info["labels_stuff"] = stuff_classes
dataset_info["labels"].update(stuff_classes)
_case_ids = get_case_ids_from_dir(self.label_dir, remove_modality=False)
case_ids_tr = get_case_ids_from_dir(self.tr_image_dir, remove_modality=True)
assert len(set(_case_ids).union(case_ids_tr)) == len(_case_ids), "All training images need a label"
dataset_info["numTraining"] = len(case_ids_tr)
dataset_info["training"] = [
{"image": f"./imagesTr/{cid}.nii.gz", "label": f"./labelsTr/{cid}.nii.gz"}
for cid in case_ids_tr]
if self.ts_image_dir is not None:
case_ids_ts = get_case_ids_from_dir(self.ts_image_dir, remove_modality=True)
dataset_info["numTest"] = len(case_ids_ts)
dataset_info["test"] = [f"./imagesTs/{cid}.nii.gz" for cid in case_ids_ts]
else:
dataset_info["numTest"] = 0
dataset_info["test"] = []
save_json(dataset_info, self.target_dir / "dataset.json")
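# Illustrative usage sketch (paths and task name are assumptions, not part of the
# original code): export a raw-splitted nnDetection task into the nnU-Net layout.
def _example_export_to_nnunet() -> None:
    data_info = load_json("Task011_Kits/dataset.json")  # needs `modalities` and `labels`
    exporter = Exporter(
        data_info=data_info,
        tr_image_dir="Task011_Kits/raw_splitted/imagesTr",
        label_dir="Task011_Kits/raw_splitted/labelsTr",
        target_dir="nnunet_export/Task011_Kits",
        export_stuff=False,
        processes=0,  # 0 -> plain for loop instead of a multiprocessing pool
    )
    exporter.export()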
import inspect
import shutil
import os
from pathlib import Path
from typing import Callable
class Registry:
def __init__(self):
self.mapping = {}
def __getitem__(self, key):
return self.mapping[key]["fn"]
def register(self, fn: Callable):
self._register(fn.__name__, fn, inspect.getfile(fn))
return fn
def _register(self, name: str, fn: Callable, path: str):
if name in self.mapping:
raise TypeError(f"Name {name} already in registry.")
else:
self.mapping[name] = {"fn": fn, "path": path}
def get(self, name: str):
return self.mapping[name]["fn"]
def copy_registered(self, target: Path):
if not target.is_dir():
target.mkdir(parents=True)
paths = [e["path"] for e in self.mapping.values()]
paths = list(set(paths))
names = [p.split('nndet')[-1] for p in paths]
names = [n.replace(os.sep, '_').rsplit('.', 1)[0] for n in names]
names = [f"{n[1:]}.py" for n in names]
for name, path in zip(names, paths):
shutil.copy(path, str(target / name))
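# Illustrative usage sketch: functions register themselves at import time and can
# later be looked up by name; the registry instance below is hypothetical.
EXAMPLE_REGISTRY = Registry()

@EXAMPLE_REGISTRY.register
def build_example_model():
    return "model"

assert EXAMPLE_REGISTRY.get("build_example_model") is build_example_model
assert EXAMPLE_REGISTRY["build_example_model"]() == "model"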
from collections import defaultdict
import torch
import re
import numpy as np
from torch import Tensor
from torch._six import container_abcs, string_classes, int_classes
from typing import Sequence, Union, Any, Mapping, Callable, List
np_str_obj_array_pattern = re.compile(r'[SaUO]')
def make_onehot_batch(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
"""
Create onehot encoding of labels
Args:
labels: label tensor to encode [N, dims]
n_classes: number of classes
Returns:
Tensor: onehot encoded tensor [N, C, dims]; N: batch size,
C: number of classes, dims: spatial dimensions
"""
idx = labels.to(dtype=torch.long)
new_shape = [labels.shape[0], n_classes, *labels.shape[1:]]
labels_onehot = torch.zeros(*new_shape, device=labels.device,
dtype=labels.dtype)
labels_onehot.scatter_(1, idx.unsqueeze(dim=1), 1)
return labels_onehot
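# Illustrative sketch: one-hot encode a batch of integer label maps.
def _example_make_onehot_batch() -> None:
    labels = torch.tensor([[0, 1, 2], [2, 2, 0]])    # [N=2, 3 voxels]
    onehot = make_onehot_batch(labels, n_classes=3)  # -> [2, 3, 3]
    assert onehot.shape == (2, 3, 3)
    assert torch.all(onehot.argmax(dim=1) == labels)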
def to_dtype(inp: Any, dtype: Callable) -> Any:
"""
helper function to convert a sequence of arguments to a specific type
Args:
inp (Any): any object which can be converted by dtype, if sequence is detected
dtype is applied to individual arguments
dtype (Callable): callable to convert arguments
Returns:
Any: converted input
"""
if isinstance(inp, Sequence):
return type(inp)([dtype(i) for i in inp])
else:
return dtype(inp)
def to_device(inp: Union[Sequence[torch.Tensor], torch.Tensor, Mapping[str, torch.Tensor]],
device: Union[torch.device, str], detach: bool = False,
**kwargs) -> \
Union[Sequence[torch.Tensor], torch.Tensor, Mapping[str, torch.Tensor]]:
"""
Push tensor or sequence of tensors to device
Args:
inp (Union[Sequence[torch.Tensor], torch.Tensor]): tensor or sequence of tensors
device (Union[torch.device, str]): target device
detach: detach tensor before moving it to new device
**kwargs: keyword arguments passed to `to` function
Returns:
Union[Sequence[torch.Tensor], torch.Tensor]: tensor or seq. of tensors at target device
"""
if isinstance(inp, torch.Tensor):
if detach:
return inp.detach().to(device=device, **kwargs)
else:
return inp.to(device=device, **kwargs)
elif isinstance(inp, Sequence):
old_type = type(inp)
return old_type([to_device(i, device=device, detach=detach, **kwargs)
for i in inp])
elif isinstance(inp, Mapping):
old_type = type(inp)
return old_type({key: to_device(item, device=device, detach=detach, **kwargs)
for key, item in inp.items()})
else:
return inp
def to_numpy(inp: Union[Sequence[torch.Tensor], torch.Tensor, Any]) -> \
Union[Sequence[np.ndarray], np.ndarray, Any]:
"""
Convert a tensor or sequence of tensors to numpy array/s
Args:
inp (Union[Sequence[torch.Tensor], torch.Tensor]): tensor or sequence of tensors
Returns:
Union[Sequence[np.ndarray], np.ndarray]: array or seq. of arrays at target device
(non tensor entries are forwarded as they are)
"""
if isinstance(inp, (tuple, list)):
old_type = type(inp)
return old_type([to_numpy(i) for i in inp])
elif isinstance(inp, dict) and not isinstance(inp, defaultdict):
old_type = type(inp)
return old_type({k: to_numpy(i) for k, i in inp.items()})
elif isinstance(inp, torch.Tensor):
return inp.detach().cpu().numpy()
else:
return inp
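# Illustrative sketch: nested containers of tensors are moved/converted recursively,
# while non-tensor entries are passed through unchanged.
def _example_to_device_to_numpy() -> None:
    batch = {"data": torch.rand(2, 1, 8, 8), "weight": 1.0}
    batch_cpu = to_device(batch, device="cpu", detach=True)
    batch_np = to_numpy(batch_cpu)
    assert isinstance(batch_np["data"], np.ndarray)
    assert batch_np["weight"] == 1.0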
def to_tensor(inp: Any) -> Any:
"""
Convert arrays, seq, mappings to torch tensor
https://github.com/pytorch/pytorch/blob/f522bde1213fcc46f2857f79be7b1b01ddf302a6/torch/utils/data/_utils/collate.py
Args:
inp: np.ndarrays, mappings, sequences are converted to tensors,
rest is passed through this function
Returns:
Any: converted data
"""
elem_type = type(inp)
if isinstance(inp, torch.Tensor):
return inp
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
# array of string classes and object
if elem_type.__name__ == 'ndarray' \
and np_str_obj_array_pattern.search(inp.dtype.str) is not None:
return inp
return torch.as_tensor(inp)
elif isinstance(inp, container_abcs.Mapping):
return {key: to_tensor(inp[key]) for key in inp}
elif isinstance(inp, tuple) and hasattr(inp, '_fields'): # namedtuple
return elem_type(*(to_tensor(d) for d in inp))
elif isinstance(inp, container_abcs.Sequence) and not isinstance(inp, string_classes):
return [to_tensor(d) for d in inp]
else:
return inp
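# Illustrative sketch: numpy arrays inside nested containers become tensors,
# while strings and object arrays are passed through unchanged.
def _example_to_tensor() -> None:
    sample = {"data": np.zeros((1, 8, 8), dtype=np.float32), "case_id": "case_0"}
    converted = to_tensor(sample)
    assert isinstance(converted["data"], torch.Tensor)
    assert converted["case_id"] == "case_0"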
def cat(t: Union[List[Tensor], Tensor], *args, **kwargs):
if not isinstance(t, (list, Tensor)):
raise ValueError("Can only concatenate lists and tensors.")
if isinstance(t, Tensor):
return t
elif len(t) == 1:
return t[0]
else:
return torch.cat(t, *args, **kwargs)
import time
from loguru import logger
class Timer:
def __init__(self, msg: str = "", verbose: bool = True):
self.verbose = verbose
self.msg = msg
self.tic: float = None
self.toc: float = None
self.dif: float = None
def __enter__(self):
self.tic = time.perf_counter()
def __exit__(self, exc_type, exc_val, exc_tb):
self.toc = time.perf_counter()
self.dif = self.toc - self.tic
if self.verbose:
logger.info(f"Operation '{self.msg}' took: {self.toc - self.tic} sec")
# Decathlon
**Disclaimer**: We are not the host of the data.
Please make sure to read the requirements and usage policies of the data and **give credit to the authors of the dataset**!
Please read the information from the homepage carefully and follow the rules and instructions provided by the original authors when using the data.
- Homepage: http://medicaldecathlon.com/
## Setup
0. Follow the installation instructions of nnDetection and create the data directories for the intended tasks, e.g. `Task003_Liver`.
1. Follow the instructions and usage policies to download the data and place the images, labels and dataset.json files inside the raw folder of the respective tasks, e.g. imagesTr -> `Task003_Liver / raw / imagesTr`, labelsTr -> `Task003_Liver / raw / labelsTr` and dataset.json -> `Task003_Liver / raw / dataset.json`
2. Run `python prepare.py [tasks]` in `projects / Task001_Decathlion / scripts` of the nnDetection repository, e.g. to prepare all tasks: `python prepare.py Task003_Liver Task007_Pancreas Task008_HepaticVessel Task010_Colon`
3. Download labels from [here](https://zenodo.org/record/4876497#.YLSudzYzYeY) and replace `labelsTr` / `labelsTs` in the splitted folder with the downloaded ones.
The data is now converted to the correct format and the instructions from the nnDetection README can be used to train the networks.
import argparse
import os
import shutil
import sys
from itertools import repeat
from multiprocessing import Pool, Value
from pathlib import Path
from loguru import logger
from nndet.io.load import save_json
from nndet.io.prepare import maybe_split_4d_nifti, create_test_split
from nndet.io import get_case_ids_from_dir, load_json, save_yaml
from nndet.utils.check import env_guard
from nndet.utils.info import maybe_verbose_iterable
def process_case(case_id,
source_images,
source_labels,
target_images,
target_labels,
):
logger.info(f"Processing case {case_id}")
maybe_split_4d_nifti(source_images / f"{case_id}.nii.gz", target_images)
shutil.copy2(source_labels / f"{case_id}.nii.gz", target_labels)
@env_guard
def main():
parser = argparse.ArgumentParser()
parser.add_argument('tasks', type=str, nargs='+',
help="One or multiple of: Task003_Liver, Task007_Pancreas, "
"Task008_HepaticVessel, Task010_Colon",
)
args = parser.parse_args()
tasks = args.tasks
decathlon_props = {
"Task003_Liver": {
"seg2det_stuff": [1, ], # liver
"seg2det_things": [2, ], # cancer
"min_size": 3.,
"labels": {"0": "cancer"},
"labels_stuff": {"1": "liver"},
},
"Task007_Pancreas": {
"seg2det_stuff": [1, ], # pancreas
"seg2det_things": [2, ],
"min_size": 3.,
"labels": {"0": "cancer"},
"labels_stuff": {"1": "pancreas"},
},
"Task008_HepaticVessel": {
"seg2det_stuff": [1, ], # vessel
"seg2det_things": [2, ],
"min_size": 3.,
"labels": {"0": "tumour"},
"labels_stuff": {"1": "vessel"},
},
"Task010_Colon": {
"seg2det_stuff": [],
"seg2det_things": [1, ],
"min_size": 3.,
"labels": {"0": "cancer"},
"labels_stuff": {},
},
}
basedir = Path(os.getenv('det_data'))
for task in tasks:
task_data_dir = basedir / task
logger.remove()
logger.add(sys.stdout, level="INFO")
logger.add(task_data_dir / "prepare.log", level="DEBUG")
logger.info(f"Preparing task: {task}")
source_raw_dir = task_data_dir / "raw"
source_data_dir = source_raw_dir / "imagesTr"
source_labels_dir = source_raw_dir / "labelsTr"
splitted_dir = task_data_dir / "raw_splitted"
if not source_data_dir.is_dir():
raise ValueError(f"Exptected training images at {source_data_dir}")
if not source_labels_dir.is_dir():
raise ValueError(f"Exptected training labels at {source_labels_dir}")
if not (p := source_raw_dir / "dataset.json").is_file():
raise ValueError(f"Expected dataset json to be located at {p}")
target_data_dir = splitted_dir / "imagesTr"
target_label_dir = splitted_dir / "labelsTr"
target_data_dir.mkdir(parents=True, exist_ok=True)
target_label_dir.mkdir(parents=True, exist_ok=True)
# prepare meta info
original_meta = load_json(source_raw_dir / "dataset.json")
dataset_info = {
"task": task,
"name": original_meta["name"],
"target_class": None,
"test_labels": True,
"modalities": original_meta["modality"],
"dim": 3,
"info": {
"original_labels": original_meta["labels"],
"original_numTraining": original_meta["numTraining"],
},
}
dataset_info.update(decathlon_props[task])
save_json(dataset_info, task_data_dir / "dataset.json")
# prepare data and labels
case_ids = get_case_ids_from_dir(source_data_dir, remove_modality=False)
case_ids = sorted([c for c in case_ids if c])
logger.info(f"Found {len(case_ids)} for preparation.")
for cid in maybe_verbose_iterable(case_ids):
process_case(cid,
source_data_dir,
source_labels_dir,
target_data_dir,
target_label_dir,
)
# with Pool(processes=6) as p:
# p.starmap(process_case, zip(case_ids,
# repeat(source_data_dir),
# repeat(source_labels_dir),
# repeat(target_data_dir),
# repeat(target_label_dir),
# ))
# create an artificial test split
create_test_split(splitted_dir=splitted_dir,
num_modalities=1,
test_size=0.3,
random_state=0,
shuffle=True,
)
if __name__ == '__main__':
main()
# Kits
**Disclaimer**: We are not the host of the data.
Please make sure to read the requirements and usage policies of the data and **give credit to the authors of the dataset**!
Please read the information from the homepage carefully and follow the rules and instructions provided by the original authors when using the data.
- Homepage: https://kits19.grand-challenge.org/data/
## Setup
0. Follow the installation instructions of nnDetection and create a data directory named `Task011_Kits`.
1. Follow the instructions and usage policies to download the data and place all the folders which contain the data and labels for each case into `Task011_Kits / raw`
2. Run `python prepare.py` in `projects / Task011_Kits / scripts` of the nnDetection repository.
3. Remove cases 15, 37, 23, 68, 125, 133 (taken from nnU-Net paper)
4. Download labels from [here](https://zenodo.org/record/4876472#.YLSv7TYzYeY) and replace `labelsTr` in the splitted folder with the downloaded ones.
The data is now converted to the correct format and the instructions from the nnDetection README can be used to train the networks.
import shutil
import os
import sys
from pathlib import Path
from loguru import logger
from nndet.io import save_json
from nndet.io.prepare import create_test_split
from nndet.utils.check import env_guard
from nndet.utils.info import maybe_verbose_iterable
@env_guard
def main():
det_data_dir = Path(os.getenv('det_data'))
task_data_dir = det_data_dir / "Task011_Kits"
source_data_dir = task_data_dir / "raw"
if not source_data_dir.is_dir():
raise RuntimeError(f"{source_data_dir} should contain the raw data but does not exist.")
splitted_dir = task_data_dir / "raw_splitted"
target_data_dir = task_data_dir / "raw_splitted" / "imagesTr"
target_data_dir.mkdir(exist_ok=True, parents=True)
target_label_dir = task_data_dir / "raw_splitted" / "labelsTr"
target_label_dir.mkdir(exist_ok=True, parents=True)
logger.remove()
logger.add(sys.stdout, level="INFO")
logger.add(task_data_dir / "prepare.log", level="DEBUG")
# save meta info
dataset_info = {
"name": "Kits",
"task": "Task011_Kits",
"target_class": None,
"test_labels": True,
"seg2det_stuff": [1,], # define stuff classes: kidney
"seg2det_things": [2,], # define things classes: tumor
"min_size": 3.,
"labels": {"0": "lesion"},
"labels_stuff": {"1": "kidney"},
"modalities": {"0": "CT"},
"dim": 3,
}
save_json(dataset_info, task_data_dir / "dataset.json")
# prepare cases
cases = [str(c.name) for c in source_data_dir.iterdir() if c.is_dir()]
for c in maybe_verbose_iterable(cases):
logger.info(f"Copy case {c}")
case_id = int(c.split("_")[-1])
if case_id < 210:
shutil.copy(source_data_dir / c / "imaging.nii.gz", target_data_dir / f"{c}_0000.nii.gz")
shutil.copy(source_data_dir / c / "segmentation.nii.gz", target_label_dir / f"{c}.nii.gz")
# create an artificial test split
create_test_split(splitted_dir=splitted_dir,
num_modalities=1,
test_size=0.3,
random_state=0,
shuffle=True,
)
if __name__ == '__main__':
main()
# LIDC
**Disclaimer**: We are not the host of the data.
Please make sure to read the requirements and usage policies of the data and **give credit to the authors of the dataset**!
Please read the information from the homepage carefully and follow the rules and instructions provided by the original authors when using the data.
- Homepage: https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI
## Setup MIC LIDC Data preprocessing
0. Follow https://github.com/MIC-DKFZ/LIDC-IDRI-processing to convert the LIDC data into a simpler format.
1. Follow the installation instructions of nnDetection and create a data directory named `Task012_LIDC`.
2. Place the `data_nrrd` folder and `characteristics.csv` into `Task012_LIDC / raw`
3. Run `python prepare_mic.py` in `projects / Task012_LIDC / scripts` of the nnDetection repository.
4. Copy the `splits_final.pkl` from `projects / Task012_LIDC` into the preprocessed folder of the data (if preprocessing has not been run yet, the folder needs to be created manually).
The data is now converted to the correct format and the instructions from the nnDetection README can be used to train the networks.
## Setup PyLIDC
**Coming Soon**