Commit 7246044d authored by mibaumgartner's avatar mibaumgartner

Merge remote-tracking branch 'origin/master' into main

parents fcec502f 6f4c3333
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import os
from pathlib import Path
from hydra.experimental import initialize_config_module
from nndet.utils.config import compose
if __name__ == '__main__':
"""
Automatically deletes files generated by seg2det and restores
the original segmentations
"""
parser = argparse.ArgumentParser()
parser.add_argument('tasks', type=str, nargs='+',
help="Single or multiple task identifiers to process consecutively",
)
args = parser.parse_args()
tasks = args.tasks
initialize_config_module(config_module="nndet.conf")
for task in tasks:
cfg = compose(task, "config.yaml", overrides=[])
print(cfg.pretty())
splitted_dir = Path(cfg["host"]["splitted_4d_output_dir"])
for postfix in ["Tr", "Ts"]:
if (p := splitted_dir / f"labels{postfix}").is_dir():
# delete everything except original files
for f in p.iterdir():
if f.is_file() and not str(f).endswith("_orig.nii.gz"):
os.remove(f)
# rename files
for f in p.glob("*.nii.gz"):
os.rename(f, f.parent / f"{f.name.rsplit('_', 1)[0]}.nii.gz")
else:
print(f"{p} is not a dir. Skipping.")
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import random
import argparse
from pathlib import Path
from multiprocessing import Pool
from itertools import repeat
import numpy as np
import SimpleITK as sitk
from loguru import logger
from nndet.io import save_json
from nndet.utils.check import env_guard
# # 2D example
# [Ignore, Not supported]
# dim = 2
# image_size = [512, 512]
# object_size = [32, 64]
# object_width = 6
# num_images_tr = 100
# num_images_ts = 100
# 3D example
dim = 3
image_size = [256, 256, 256]
object_size = [16, 32]
object_width = 4
def generate_image(image_dir, label_dir, idx):
random.seed(idx)
np.random.seed(idx)
logger.info(f"Generating case_{idx}")
selected_size = np.random.randint(object_size[0], object_size[1])
selected_class = np.random.randint(0, 2)
data = np.random.rand(*image_size)
mask = np.zeros_like(data)
top_left = [np.random.randint(0, image_size[i] - selected_size) for i in range(dim)]
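# selected_class 0 draws a filled square, selected_class 1 draws a hollow square
# (the "Square" and "SquareHole" labels defined in the dataset meta below)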
if selected_class == 0:
slicing = tuple([slice(tp, tp + selected_size) for tp in top_left])
data[slicing] = data[slicing] + 0.4
data = data.clip(0, 1)
mask[slicing] = 1
elif selected_class == 1:
slicing = tuple([slice(tp, tp + selected_size) for tp in top_left])
inner_slicing = [slice(tp + object_width, tp + selected_size - object_width) for tp in top_left]
if len(inner_slicing) == 3:
inner_slicing[0] = slice(0, image_size[0])
inner_slicing = tuple(inner_slicing)
object_mask = np.zeros_like(mask).astype(bool)
object_mask[slicing] = 1
object_mask[inner_slicing] = 0
data[object_mask] = data[object_mask] + 0.4
data = data.clip(0, 1)
mask[object_mask] = 1
else:
raise NotImplementedError
if dim == 2:
data = data[None]
mask = mask[None]
data_itk = sitk.GetImageFromArray(data)
mask_itk = sitk.GetImageFromArray(mask)
mask_meta = {
"instances": {
"1": selected_class
},
}
sitk.WriteImage(data_itk, str(image_dir / f"case_{idx}_0000.nii.gz"))
sitk.WriteImage(mask_itk, str(label_dir / f"case_{idx}.nii.gz"))
save_json(mask_meta, label_dir / f"case_{idx}.json")
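# Orientation note (added comment): each call above writes three files per case,
# following the raw_splitted naming scheme set up in main(), e.g. for idx == 7:
#   imagesTr/case_7_0000.nii.gz  -> image (modality suffix _0000)
#   labelsTr/case_7.nii.gz       -> instance mask (single instance with id 1)
#   labelsTr/case_7.json         -> instance annotation, e.g. {"instances": {"1": 0}}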
@env_guard
def main():
"""
Generate an example dataset for nnDetection to test the installation or
experiment with ideas.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
'--full',
help="Increase size of dataset. "
"Default sizes train/test 10/10 and full 1000/1000.",
action='store_true',
)
parser.add_argument(
'--num_processes',
help="Use multiprocessing to create dataset.",
type=int,
default=0,
)
args = parser.parse_args()
full = args.full
num_processes = args.num_processes
num_images_tr = 1000 if full else 10
num_images_ts = 1000 if full else 10
meta = {
"task": f"Task000D{dim}_Example",
"name": "Example",
"target_class": None,
"test_labels": True,
"labels": {"0": "Square", "1": "SquareHole"},
"modalities": {"0": "MRI"},
"dim": dim,
}
# setup paths
data_task_dir = Path(os.getenv("det_data")) / meta["task"]
data_task_dir.mkdir(parents=True, exist_ok=True)
save_json(meta, data_task_dir / "dataset.json")
raw_splitted_dir = data_task_dir / "raw_splitted"
images_tr_dir = raw_splitted_dir / "imagesTr"
images_tr_dir.mkdir(parents=True, exist_ok=True)
labels_tr_dir = raw_splitted_dir / "labelsTr"
labels_tr_dir.mkdir(parents=True, exist_ok=True)
images_ts_dir = raw_splitted_dir / "imagesTs"
images_ts_dir.mkdir(parents=True, exist_ok=True)
labels_ts_dir = raw_splitted_dir / "labelsTs"
labels_ts_dir.mkdir(parents=True, exist_ok=True)
if num_processes == 0:
for idx in range(num_images_tr):
generate_image(
images_tr_dir,
labels_tr_dir,
idx,
)
for idx in range(num_images_tr, num_images_tr + num_images_ts):
generate_image(
images_ts_dir,
labels_ts_dir,
idx,
)
else:
logger.info("Using multiprocessing to create example dataset.")
with Pool(processes=num_processes) as p:
p.starmap(
generate_image,
zip(
repeat(images_tr_dir),
repeat(labels_tr_dir),
range(num_images_tr),
)
)
with Pool(processes=num_processes) as p:
p.starmap(
generate_image,
zip(
repeat(images_ts_dir),
repeat(labels_ts_dir),
range(num_images_tr, num_images_tr + num_images_ts),
)
)
if __name__ == '__main__':
main()
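# Hypothetical invocation (the script/entry-point name is assumed here, it is not
# defined in this file): running the module directly would create
# Task000D3_Example under $det_data with 10/10 train/test cases, or 1000/1000
# when --full is passed:
#
#   python generate_example.py --full --num_processes 4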
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
import os
import argparse
import sys
from pathlib import Path
from loguru import logger
from omegaconf import OmegaConf
from hydra.experimental import initialize_config_module
from nnunet.paths import nnUNet_raw_data
from nndet.io import get_task
from nndet.utils.config import compose
from nndet.utils.nnunet import Exporter
def run(cfg, target_dir, stuff: bool):
base_dir = Path(cfg.host.splitted_4d_output_dir)
target_dir.mkdir(exist_ok=True, parents=True)
if (base_dir / "imagesTs").is_dir():
logger.info("Found test images and will export them too")
ts_image_dir = base_dir / "imagesTs"
else:
ts_image_dir = None
exporter = Exporter(data_info=OmegaConf.to_container(cfg.data),
tr_image_dir=base_dir / "imagesTr",
ts_image_dir=ts_image_dir,
label_dir=base_dir / "labelsTr",
target_dir=target_dir,
export_stuff=stuff,
).export()
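# Orientation note (added comment): the Exporter above reads the nnDetection
# raw_splitted layout (imagesTr / labelsTr and, if present, imagesTs) and writes
# it into target_dir as an nnU-Net raw-data task (a folder below
# nnUNet_raw_data, see the __main__ block below).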
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('tasks', type=str, nargs='+',
help="Single or multiple task identifiers to process consecutively",
)
parser.add_argument('-nt', '--new_tasks', type=str, nargs='+',
help="Rename the tasks.",
required=False, default=None,
)
parser.add_argument("--stuff", action='store_true',
help="Export stuff and things classes."
"The final detection evaluation will be performed on things classes only.")
parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file",
required=False,
)
args = parser.parse_args()
tasks = args.tasks
new_tasks = args.new_tasks
ov = args.overwrites
stuff = args.stuff
print(f"Overwrites: {ov}")
initialize_config_module(config_module="nndet.conf")
if new_tasks is None:
new_tasks = tasks
for task, new_task in zip(tasks, new_tasks):
task = get_task(task, name=True)
if nnUNet_raw_data is None:
raise RuntimeError(f"Please set `nnUNet_raw_data` for nnUNet!")
target_dir = Path(nnUNet_raw_data) / new_task
logger.remove()
logger.add(sys.stdout, level="INFO")
logger.add(target_dir / "nnunet_export.log", level="DEBUG")
_ov = copy.deepcopy(ov) if ov is not None else []
cfg = compose(task, "config.yaml", overrides=_ov)
print(cfg.pretty())
run(cfg, target_dir, stuff=stuff)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
from nndet.io.load import save_json
import os
import sys
import shutil
from functools import partial
from itertools import repeat
from multiprocessing import Pool
from pathlib import Path, PurePath
from typing import Union, Sequence, Optional
import numpy as np
from hydra.experimental import initialize_config_module
from loguru import logger
from nndet.evaluator.registry import evaluate_box_dir
from nndet.io import load_pickle, save_pickle, get_task, load_json
from nndet.utils.clustering import instance_results_from_seg
from nndet.utils.config import compose
from nndet.utils.info import maybe_verbose_iterable
Pathlike = Union[str, Path]
TARGET_METRIC = "mAP_IoU_0.10_0.50_0.05_MaxDet_100"
def import_nnunet_boxes(
# settings
nnunet_prediction_dir: Pathlike,
save_dir: Pathlike,
boxes_gt_dir: Pathlike,
classes: Sequence[str],
stuff: Optional[Sequence[int]] = None,
num_workers: int = 6,
):
assert nnunet_prediction_dir.is_dir(), f"{nnunet_prediction_dir} is not a dir"
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
summary = []
# create sweep dir
sweep_dir = Path(nnunet_prediction_dir)
postprocessing_settings = {}
# optimize min num voxels
logger.info("Looking for optimal min voxel size")
min_num_voxel_settings = [0, 5, 10, 15, 20]
scores = []
for min_num_voxel in min_num_voxel_settings:
# create temp dir
sweep_prediction = sweep_dir / f"sweep_min_voxel{min_num_voxel}"
sweep_prediction.mkdir(parents=True)
# import with settings
import_dir(
nnunet_prediction_dir=nnunet_prediction_dir,
target_dir=sweep_prediction,
min_num_voxel=min_num_voxel,
save_seg=False,
save_iseg=False,
stuff=stuff,
num_workers=num_workers,
)
# evaluate
_scores, _ = evaluate_box_dir(
pred_dir=sweep_prediction,
gt_dir=boxes_gt_dir,
classes=classes,
save_dir=None,
)
scores.append(_scores[TARGET_METRIC])
summary.append({f"Min voxel {min_num_voxel}": _scores[TARGET_METRIC]})
logger.info(f"Min voxel {min_num_voxel} :: {_scores[TARGET_METRIC]}")
shutil.rmtree(sweep_prediction)
idx = int(np.argmax(scores))
postprocessing_settings["min_num_voxel"] = min_num_voxel_settings[idx]
logger.info(f"Found min num voxel {min_num_voxel_settings[idx]} with score {scores[idx]}")
# optimize score threshold
logger.info("Looking for optimal min probability threshold")
min_threshold_settings = [None, 0.1, 0.2, 0.3, 0.4, 0.5]
scores = []
for min_threshold in min_threshold_settings:
# create temp dir
sweep_prediction = sweep_dir / f"sweep_min_threshold_{min_threshold}"
sweep_prediction.mkdir(parents=True)
# import with settings
import_dir(
nnunet_prediction_dir=nnunet_prediction_dir,
target_dir=sweep_prediction,
min_threshold=min_threshold,
save_seg=False,
save_iseg=False,
stuff=stuff,
num_workers=num_workers,
**postprocessing_settings,
)
# evaluate
_scores, _ = evaluate_box_dir(
pred_dir=sweep_prediction,
gt_dir=boxes_gt_dir,
classes=classes,
save_dir=None,
)
scores.append(_scores[TARGET_METRIC])
summary.append({f"Min score {min_threshold}": _scores[TARGET_METRIC]})
logger.info(f"Min score {min_threshold} :: {_scores[TARGET_METRIC]}")
shutil.rmtree(sweep_prediction)
idx = int(np.argmax(scores))
postprocessing_settings["min_threshold"] = min_threshold_settings[idx]
logger.info(f"Found min threshold {min_threshold_settings[idx]} with score {scores[idx]}")
logger.info("Looking for best probability aggregation")
aggreagtion_settings = ["max", "median", "mean", "percentile95"]
scores = []
for aggregation in aggregation_settings:
# create temp dir
sweep_prediction = sweep_dir / f"sweep_aggregation_{aggregation}"
sweep_prediction.mkdir(parents=True)
# import with settings
import_dir(
nnunet_prediction_dir=nnunet_prediction_dir,
target_dir=sweep_prediction,
aggregation=aggregation,
save_seg=False,
save_iseg=False,
stuff=stuff,
num_workers=num_workers,
**postprocessing_settings,
)
# evaluate
_scores, _ = evaluate_box_dir(
pred_dir=sweep_prediction,
gt_dir=boxes_gt_dir,
classes=classes,
save_dir=None,
)
scores.append(_scores[TARGET_METRIC])
summary.append({f"Aggreagtion {aggregation}": _scores[TARGET_METRIC]})
logger.info(f"Aggreagtion {aggregation} :: {_scores[TARGET_METRIC]}")
shutil.rmtree(sweep_prediction)
idx = int(np.argmax(scores))
postprocessing_settings["aggregation"] = aggreagtion_settings[idx]
logger.info(f"Found aggregation {aggreagtion_settings[idx]} with score {scores[idx]}")
save_pickle(postprocessing_settings, save_dir / "postprocessing.pkl")
save_json(summary, save_dir / "summary.json")
return postprocessing_settings
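# Note (added comment): the postprocessing.pkl saved above holds the three swept
# settings that import_dir consumes, e.g. something along the lines of
#   {"min_num_voxel": 10, "min_threshold": 0.2, "aggregation": "max"}
# (values purely illustrative).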
def import_dir(
nnunet_prediction_dir: Pathlike,
target_dir: Optional[Pathlike] = None,
aggregation="max",
min_num_voxel=0,
min_threshold=None,
save_seg: bool = True,
save_iseg: bool = True,
stuff: Optional[Sequence[int]] = None,
num_workers: int = 6,
):
source = [f for f in nnunet_prediction_dir.iterdir() if f.suffix == ".npz"]
_fn = partial(import_single_case,
aggregation=aggregation,
min_num_voxel=min_num_voxel,
min_threshold=min_threshold,
save_seg=save_seg,
save_iseg=save_iseg,
stuff=stuff,
)
# for s in maybe_verbose_iterable(source):
# _fn(s, target_dir)
with Pool(processes=num_workers) as p:
p.starmap(_fn, zip(source, repeat(target_dir)))
def import_single_case(logits_source: Path,
logits_target_dir: Optional[Path],
aggregation: str,
min_num_voxel: int,
min_threshold: Optional[float],
save_seg: bool = True,
save_iseg: bool = True,
stuff: Optional[Sequence[int]] = None,
):
"""
Process a single case
Args:
logits_source: path to nnunet prediction
logits_target_dir: path to dir where result should be saved
aggregation: aggregation method used to derive the instance probability
min_num_voxel: minimum number of voxels an instance needs to be kept
min_threshold: minimum probability an instance needs to be kept; None disables the filter
save_seg: save semantic segmentation
save_iseg: save instance segmentation
stuff: stuff classes to remove
"""
assert logits_source.is_file(), f"Logits source needs to be a file, found {logits_source}"
assert logits_target_dir.is_dir(), f"Logits target dir needs to be a dir, found {logits_target_dir}"
case_name = logits_source.stem
logger.info(f"Processing {case_name}")
properties_file = logits_source.parent / f"{case_name}.pkl"
probs = np.load(str(logits_source))["softmax"]
if properties_file.is_file():
properties_dict = load_pickle(properties_file)
bbox = properties_dict.get('crop_bbox')
shape_original_before_cropping = properties_dict.get('original_size_of_raw_data')
if bbox is not None:
tmp = np.zeros((probs.shape[0], *shape_original_before_cropping))
for c in range(3):
bbox[c][1] = np.min((bbox[c][0] + probs.shape[c + 1], shape_original_before_cropping[c]))
tmp[:, bbox[0][0]:bbox[0][1], bbox[1][0]:bbox[1][1], bbox[2][0]:bbox[2][1]] = probs
probs = tmp
res = instance_results_from_seg(probs,
aggregation=aggregation,
min_num_voxel=min_num_voxel,
min_threshold=min_threshold,
stuff=stuff,
)
detection_target = logits_target_dir / f"{case_name}_boxes.pkl"
segmentation_target = logits_target_dir / f"{case_name}_segmentation.pkl"
instances_target = logits_target_dir / f"{case_name}_instances.pkl"
boxes = {key: res[key] for key in ["pred_boxes", "pred_labels", "pred_scores"]}
save_pickle(boxes, detection_target)
if save_iseg:
instances = {key: res[key] for key in ["pred_instances", "pred_labels", "pred_scores"]}
save_pickle(instances, instances_target)
if save_seg:
segmentation = {"pred_seg": np.argmax(probs, axis=0)}
save_pickle(segmentation, segmentation_target)
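# Note (added comment): the files written above are plain pickled dicts:
# *_boxes.pkl with "pred_boxes", "pred_labels", "pred_scores",
# *_instances.pkl with "pred_instances", "pred_labels", "pred_scores" and
# *_segmentation.pkl with "pred_seg" (argmax over the softmax channels).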
def nnunet_dataset_json(nnunet_task: str):
if (p := os.getenv("nnUNet_raw_data_base")) is not None:
search_dir = Path(p) / "nnUNet_raw_data" / nnunet_task
logger.info(f"Looking for dataset.json in {search_dir}")
if (fp := search_dir / "dataset.json").is_file():
return load_json(fp)
elif (p := os.getenv("nnUNet_preprocessed")) is not None:
search_dir = Path(p) / nnunet_task
logger.info(f"Looking for dataset.json in {search_dir}")
if (fp := search_dir / "dataset.json").is_file():
return load_json(fp)
else:
raise ValueError("Was not able to find nnunet dataset.json")
def copy_and_ensemble(cid, nnunet_dirs, nnunet_prediction_dir):
logger.info(f"Copy and ensemble: {cid}")
case = [np.load(_nnunet_dir / f"fold_{fold}" / "validation_raw" / f"{cid}.npz")["softmax"] for _nnunet_dir in nnunet_dirs]
assert len(case) == len(nnunet_dirs)
case_ensemble = np.mean(case, axis=0)
assert case_ensemble.shape == case[0].shape
np.savez_compressed(nnunet_prediction_dir / f"{cid}.npz", softmax=case_ensemble)
def copy_and_ensemble_test(cid, nnunet_dirs, nnunet_prediction_dir):
logger.info(f"Copy and ensemble: {cid}")
case = [np.load(_nnunet_dir / f"{cid}.npz")["softmax"] for _nnunet_dir in nnunet_dirs]
assert len(case) == len(nnunet_dirs)
case_ensemble = np.mean(case, axis=0)
assert case_ensemble.shape == case[0].shape
np.savez_compressed(nnunet_prediction_dir / f"{cid}.npz", softmax=case_ensemble)
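# Sketch of the ensembling above (illustrative only): the per-model softmax
# volumes are averaged element-wise, so the ensemble keeps the shape of each
# individual prediction, e.g. for two hypothetical predictions:
#
#   a = np.array([[0.9, 0.1]]); b = np.array([[0.7, 0.3]])
#   np.mean([a, b], axis=0)  # -> array([[0.8, 0.2]])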
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--nnunet', type=Path, nargs='+',
help='if val: path to the nnunet dir, e.g. '
'../nnUNet/3d_fullres/TaskX/nnUNetTrainerV2__nnUNetPlansv2.1 '
'if test: path to prediction dirs to ensemble. Val mode needed to be run before!',
required=True,
)
parser.add_argument('-m', '--mode', type=str, required=True,
help="Provide operation mode. 'val' will ensemble and run "
"empirical optimization. 'test' will load settings and postprocess.")
parser.add_argument('-t', '--task', type=str, default=None,
help="detection task id, needed to determine stuff classes"
"If it is not provided via an argument the script tries to determine "
"it from the nnunet path, this works only if the task names are identical!"
"Need to provide task id in test mode!",
required=False,
)
parser.add_argument('-pf', '--prefix', type=str, default='val',
help="Prefix for folder. One of 'val', 'test'",
required=False,
)
parser.add_argument('--num_workers', type=int, default=6,
help="Number of worker to use",
required=False,
)
parser.add_argument('--simple', action='store_true',
help="Argmax with max probability aggregation.",
)
# Evaluation related settings
parser.add_argument('--save_seg', help="Save semantic segmentation", action='store_true')
parser.add_argument('--save_iseg', help="Save instance segmentation", action='store_true')
args = parser.parse_args()
nnunet_dirs = args.nnunet
task = args.task
prefix = args.prefix
mode = args.mode
num_workers = args.num_workers
simple = args.simple
save_seg = args.save_seg
save_iseg = args.save_iseg
# select corresponding nnDetection task
nnunet_dir = nnunet_dirs[0]
task_names = [n for n in PurePath(nnunet_dir).parts if "Task" in n]
if len(task_names) > 1:
logger.error(f"Found multiple task names trying to continue with {task_names[-1]}")
logger.info(f"Found nnunet task {task_names[-1]} in nnunet path")
nnunet_task = task_names[-1]
if task is None:
logger.info(f"Using nnunet task {nnunet_task} as detection task id")
task = nnunet_task
else:
task = get_task(task, name=True)
task_dir = Path(os.getenv("det_models")) / task
initialize_config_module(config_module="nndet.conf")
cfg = compose(task, "config.yaml", overrides=[])
logger.remove()
logger.add(sys.stdout, level="INFO")
log_file = task_dir / "nnUNet" / "import.log"
logger.add(log_file, level="INFO")
if simple:
nndet_unet_dir = task_dir / "nnUNet_Simple" / "consolidated"
else:
nndet_unet_dir = task_dir / "nnUNet" / "consolidated"
instance_classes = cfg["data"]["labels"]
stuff_classes = cfg.get("labels_stuff", {})
num_instance_classes = len(instance_classes)
stuff_classes = {
str(int(key) + num_instance_classes): item
for key, item in stuff_classes.items() if int(key) > 0
}
stuff = [int(s) for s in stuff_classes.keys()]
if mode.lower() == "val":
nnunet_prediction_dir = nndet_unet_dir /f"validation_raw_all"
nnunet_prediction_dir.mkdir(parents=True, exist_ok=True)
# copy all predictions from nnunet into one directory
for fold in range(5):
case_ids = [p.stem for p in (nnunet_dir / f"fold_{fold}" / "validation_raw").iterdir() if p.name.endswith(".npz")]
logger.info(f"Copy and ensemble results fold {fold} with {len(case_ids)} cases.")
# copy properties
for p in [p for p in (nnunet_dir / f"fold_{fold}" / "validation_raw").iterdir() if p.name.endswith(".pkl")]:
shutil.copyfile(p, nnunet_prediction_dir / p.name)
if num_workers > 0:
with Pool(processes=max(num_workers // 4, 1)) as p:
p.starmap(copy_and_ensemble,
zip(case_ids,
repeat(nnunet_dirs),
repeat(nnunet_prediction_dir),
))
else:
for cid in case_ids:
copy_and_ensemble(cid, nnunet_dirs, nnunet_prediction_dir)
if simple:
postprocessing_settings = {
"aggregation": "max",
"min_num_voxel": 5,
"min_threshold": None,
}
save_pickle(postprocessing_settings, nndet_unet_dir / "postprocessing.pkl")
else:
postprocessing_settings = import_nnunet_boxes(
nnunet_prediction_dir=nnunet_prediction_dir,
save_dir=nndet_unet_dir,
boxes_gt_dir=Path(os.getenv("det_data")) / task / "preprocessed" / "labelsTr",
classes=list(cfg["data"]["labels"].keys()),
stuff=stuff,
num_workers=num_workers,
)
save_pickle({}, nndet_unet_dir / "plan.pkl")
target_dir = nndet_unet_dir / "val_predictions"
else:
case_ids = [p.stem for p in nnunet_dir.iterdir() if p.name.endswith(".npz")]
nnunet_prediction_dir = nndet_unet_dir /f"test_raw_all"
nnunet_prediction_dir.mkdir(parents=True, exist_ok=True)
if num_workers > 0:
with Pool(processes=max(num_workers // 4, 1)) as p:
p.starmap(copy_and_ensemble_test,
zip(case_ids,
repeat(nnunet_dirs),
repeat(nnunet_prediction_dir),
))
else:
for cid in case_ids:
copy_and_ensemble_test(cid, nnunet_dirs, nnunet_prediction_dir)
postprocessing_settings = load_pickle(nndet_unet_dir / "postprocessing.pkl")
target_dir = nndet_unet_dir / "test_predictions"
logger.info(f"Creating final predictions")
target_dir.mkdir(parents=True, exist_ok=True)
import_dir(
nnunet_prediction_dir=nnunet_prediction_dir,
target_dir=target_dir,
save_seg=save_seg,
save_iseg=save_iseg,
stuff=stuff,
num_workers=num_workers,
**postprocessing_settings,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
from multiprocessing import Value
import os
import sys
from typing import Any, Mapping, Type, TypeVar
from omegaconf import OmegaConf
from loguru import logger
from pathlib import Path
from nndet.utils.check import env_guard
from nndet.planning import PLANNER_REGISTRY
from nndet.io import get_task, get_training_dir
from nndet.io.load import load_pickle
from nndet.inference.loading import load_all_models
from nndet.inference.helper import predict_dir
from nndet.utils.check import check_data_and_label_splitted
def run(cfg: dict,
training_dir: Path,
process: bool = True,
num_models: int = None,
num_tta_transforms: int = None,
test_split: bool = False,
num_processes: int = 3,
):
"""
Run inference pipeline
Args:
cfg: configurations
training_dir: path to model directory
process: preprocess test data
num_models: number of models to use for the ensemble; if None, all models
are used
num_tta_transforms: number of TTA transforms; if None, the maximum
number of transforms is used
test_split: Typical usage of nnDetection will never require
this option! Predict an already preprocessed split of the original
training data. The 'test' split needs to be located in fold 0
of a manually created split file.
num_processes: number of processes to use for preprocessing
"""
plan = load_pickle(training_dir / "plan_inference.pkl")
preprocessed_output_dir = Path(cfg["host"]["preprocessed_output_dir"])
prediction_dir = training_dir / "test_predictions"
logger.remove()
logger.add(sys.stdout, format="{level} {message}", level="INFO")
logger.add(Path(training_dir) / "inference.log", level="INFO")
if process:
planner_cls = PLANNER_REGISTRY.get(plan["planner_id"])
planner_cls.run_preprocessing_test(
preprocessed_output_dir=preprocessed_output_dir,
splitted_4d_output_dir=cfg["host"]["splitted_4d_output_dir"],
plan=plan,
num_processes=num_processes,
)
prediction_dir.mkdir(parents=True, exist_ok=True)
if test_split:
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTr"
case_ids = load_pickle(training_dir / "splits.pkl")[0]["test"]
else:
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTs"
case_ids = None
predict_dir(source_dir=source_dir,
target_dir=prediction_dir,
cfg=cfg,
plan=plan,
source_models=training_dir,
num_models=num_models,
num_tta_transforms=num_tta_transforms,
model_fn=load_all_models,
restore=True,
case_ids=case_ids,
**cfg.get("inference_kwargs", {}),
)
def set_arg(cfg: Mapping, key: str, val: Any, force_args: bool) -> Mapping:
"""
Check whether the config value for the given key matches the provided value and handle it appropriately:
If the values match, no action is performed.
If the values do not match and force_args is enabled, the value
in the config is overwritten.
If the values do not match and force_args is disabled, a ValueError
is raised.
Args:
cfg: config to check and write values to
key: key to check.
val: Potentially new value.
force_args: Enable if config value should be overwritten if values do
not match.
Returns:
Type[dict]: config with potentially changed key
"""
if key not in cfg:
raise ValueError(f"{key} is not in config.")
if cfg[key] != val:
if force_args:
logger.warning(f"Found different values for {key}, will overwrite {cfg[key]} with {val}")
cfg[key] = val
else:
raise ValueError(f"Found different values for {key} and overwrite disabled."
f"Found {cfg[key]} but expected {val}.")
return cfg
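# Hypothetical usage of set_arg (values made up for illustration): with
# force_args=False a mismatch raises a ValueError, with force_args=True the
# config value is overwritten and a warning is logged:
#
#   cfg = {"task": "Task000D3_Example"}
#   set_arg(cfg, "task", "Task000D3_Example", force_args=False)  # unchanged
#   set_arg(cfg, "task", "Task001_Other", force_args=True)       # overwrites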
@env_guard
def main():
parser = argparse.ArgumentParser()
parser.add_argument('task', type=str, help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
parser.add_argument('model', type=str, help="model name, e.g. RetinaUNetV0")
parser.add_argument('-f', '--fold', type=int, required=False, default=-1,
help="fold to use for prediction. -1 uses the consolidated model",
)
parser.add_argument('-nmodels', '--num_models', type=int, default=None,
required=False,
help="number of models for ensemble(per default all models will be used)."
"NOT usable by default -- will use all models inside the folder!",
)
parser.add_argument('-ntta', '--num_tta', type=int, default=None,
help="number of tta transforms (per default most tta are chosen)",
required=False,
)
parser.add_argument('-o', '--overwrites', type=str, nargs='+',
default=None,
required=False,
help=("overwrites for config file. "
"inference_kwargs can be used to add additional "
"keyword arguments to inference."),
)
parser.add_argument('--no_preprocess', action='store_false', help="Preprocess test data")
parser.add_argument('--force_args', action='store_true',
help=("When transferring models betweens tasks the name "
"and fold might differ from the original one. "
"This forces an overwrite to the passed in arguments of"
" this function. This can be dangerous!"),
)
parser.add_argument('--test_split', action='store_true',
help=("Typical usage of nnDetection will never require "
"this option! Predict an already preprocessed "
"split of the original training data. "
"The 'test' split needs to be located in fold 0 "
"of a manually created split file."),
)
parser.add_argument('--check',
help="Run check of the test data before predicting",
action='store_true',
)
parser.add_argument('-npp', '--num_processes_preprocessing',
type=int, default=3, required=False,
help="Number of processes to use for resampling.",
)
args = parser.parse_args()
model = args.model
fold = args.fold
task = args.task
num_models = args.num_models
num_tta_transforms = args.num_tta
ov = args.overwrites
force_args = args.force_args
test_split = args.test_split
check = args.check
num_processes = args.num_processes_preprocessing
task_name = get_task(task, name=True)
task_model_dir = Path(os.getenv("det_models"))
training_dir = get_training_dir(task_model_dir / task_name / model, fold)
process = args.no_preprocess
if test_split and process:
raise ValueError("When using the test split option raw data is not "
"supported. Need to add --no_preprocess flag!")
cfg = OmegaConf.load(str(training_dir / "config.yaml"))
cfg = set_arg(cfg, "task", task_name, force_args=force_args)
cfg["exp"] = set_arg(cfg["exp"], "fold", fold,
force_args=True if fold == -1 else force_args)
cfg["exp"] = set_arg(cfg["exp"], "id", model, force_args=force_args)
overwrites = ov if ov is not None else []
overwrites.append("host.parent_data=${env:det_data}")
overwrites.append("host.parent_results=${env:det_models}")
cfg.merge_with_dotlist(overwrites)
if check:
if test_split:
raise ValueError("Check is not supported for test split option.")
check_data_and_label_splitted(
task_name=cfg["task"],
test=True,
labels=False,
full_check=True
)
run(OmegaConf.to_container(cfg, resolve=True),
training_dir,
process=process,
num_models=num_models,
num_tta_transforms=num_tta_transforms,
test_split=test_split,
num_processes=num_processes,
)
if __name__ == '__main__':
main()
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import shutil
import os
import copy
import sys
import traceback
import numpy as np
from loguru import logger
from itertools import repeat
from typing import Dict, Sequence, Tuple, List
from pathlib import Path
from multiprocessing import Pool
from hydra.experimental import initialize_config_module
from omegaconf import OmegaConf
from nndet.utils.config import compose
from nndet.utils.check import env_guard
from nndet.planning import DatasetAnalyzer
from nndet.planning import PLANNER_REGISTRY
from nndet.planning.experiment.utils import create_labels
from nndet.planning.properties.registry import medical_instance_props
from nndet.io.load import load_pickle, load_npz_looped
from nndet.io.prepare import maybe_split_4d_nifti, instances_from_segmentation
from nndet.io.paths import get_paths_raw_to_split, get_paths_from_splitted_dir, subfiles, get_case_id_from_path
from nndet.preprocessing import ImageCropper
from nndet.utils.check import check_dataset_file, check_data_and_label_splitted
def run_splitting_4d(data_dir: Path, output_dir: Path, num_processes: int) -> None:
"""
For historical reasons this framework uses 3D niftis instead of 4D niftis.
This function splits any 4D niftis into one 3D nifti per channel.
Args:
data_dir (Path): top directory where the data is located
output_dir (Path): output directory for the splitted data
num_processes (int): number of processes to use to split the data
"""
if output_dir.is_dir():
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True)
source_files, target_folders = get_paths_raw_to_split(data_dir, output_dir)
with Pool(processes=num_processes) as p:
p.starmap(maybe_split_4d_nifti, zip(source_files, target_folders))
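# Orientation note (added comment): following the nnU-Net style naming used
# elsewhere in this repository, a 4D nifti "case_0.nii.gz" with two channels
# would be split into "case_0_0000.nii.gz" and "case_0_0001.nii.gz", one 3D
# file per modality.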
def prepare_labels(data_dir: Path,
output_dir: Path,
num_processes: int,
rm_classes: Sequence[int],
ro_classes: Dict[int, int],
subtract_one_from_classes: bool,
instances_from_seg: bool = True):
"""
Copy labels to splitted dir.
Optionally, runs connected components and removes classes from
semantic segmentations
Args:
data_dir: path to task base dir
output_dir: base dir of prepared labels
num_processes: number of processes to use
rm_classes: remove specific classes from semantic segmentation.
Can only be used with `instances_from_seg`
ro_classes: reorder classes in semantic segmentation for
connected components. Can only be used with `instances_from_seg`
subtract_one_from_classes: class indices for detection start from 0.
Subtracts 1 from classes extracted from segmentation
instances_from_seg: Run connected components. Defaults to True.
"""
for labels_subdir in ("labelsTr", "labelsTs"):
if not (data_dir / labels_subdir).is_dir():
continue
labels_output_dir = output_dir / labels_subdir
if instances_from_seg:
if not labels_output_dir.is_dir():
labels_output_dir.mkdir(parents=True)
with Pool(processes=num_processes) as p:
paths = list(map(Path, subfiles(data_dir / labels_subdir,
identifier="*.nii.gz", join=True)))
paths = [path for path in paths if not path.name.startswith('.')]
p.starmap(instances_from_segmentation, zip(
paths, repeat(labels_output_dir), repeat(rm_classes),
repeat(ro_classes), repeat(subtract_one_from_classes)))
else:
shutil.copytree(data_dir / labels_subdir, labels_output_dir)
def run_cropping_and_convert(cropped_output_dir: Path,
splitted_4d_output_dir: Path,
data_info: dict,
overwrite: bool,
num_processes: int,
):
"""
First data preparation step:
- stack data and segmentation into a single sample (segmentation is the last channel)
- save data as npz (format: case_id.npz)
- save additional properties as pkl file (format: case_id.pkl)
- crop data to the nonzero region; crop the segmentation accordingly; fill the segmentation with -1 outside the nonzero region
Args:
cropped_output_dir (Path): path to directory where cropped images should be saved
splitted_4d_output_dir (Path): path to splitted data
data_info: information about data set (here `modalities` is needed)
overwrite (bool): overwrite existing cropped data
num_processes (int): number of processes used to crop image data
"""
num_modalities = len(data_info["modalities"].keys())
if overwrite and cropped_output_dir.is_dir():
shutil.rmtree(str(cropped_output_dir))
if not cropped_output_dir.is_dir():
cropped_output_dir.mkdir(parents=True)
case_files = get_paths_from_splitted_dir(num_modalities, splitted_4d_output_dir)
logger.info(f"Running cropping with overwrite {overwrite}.")
imgcrop = ImageCropper(num_processes, cropped_output_dir)
imgcrop.run_cropping(case_files, overwrite_existing=overwrite)
case_ids_failed, result_check = run_check(cropped_output_dir / "imagesTr",
remove=True,
processes=num_processes,
keys=("data",)
)
if not result_check:
logger.warning(
f"Crop check failed: There are corrupted files!!!! {case_ids_failed}"
f"Try to crop corrupted files again.",
)
imgcrop = ImageCropper(0, cropped_output_dir)
imgcrop.run_cropping(case_files, overwrite_existing=False)
case_ids_failed, result_check = run_check(cropped_output_dir / "imagesTr",
remove=False,
processes=num_processes,
keys=("data",)
)
if not result_check:
logger.error(f"Found corrupted files: {case_ids_failed}.")
raise RuntimeError("Corrupted files")
else:
logger.info(f"Crop check successful: Loading check completed")
def run_dataset_analysis(cropped_output_dir: Path,
preprocessed_output_dir: Path,
data_info: dict,
num_processes: int,
intensity_properties: bool = True,
overwrite: bool = True,
):
"""
Analyse dataset
Args:
cropped_output_dir: path to base cropped dir
preprocessed_output_dir: path to base preprocessed output dir
data_info: additional information about dataset (`modalities` and `labels` needed)
num_processes: number of processes to use
intensity_properties: analyze intensity values of foreground
overwrite: overwrite existing properties
"""
analyzer = DatasetAnalyzer(
cropped_output_dir,
preprocessed_output_dir=preprocessed_output_dir,
data_info=data_info,
num_processes=num_processes,
overwrite=overwrite,
)
properties = medical_instance_props(intensity_properties=intensity_properties)
_ = analyzer.analyze_dataset(properties)
def run_planning_and_process(
splitted_4d_output_dir: Path,
cropped_output_dir: Path,
preprocessed_output_dir: Path,
planner_name: str,
dim: int,
model_name: str,
model_cfg: Dict,
num_processes: int,
run_preprocessing: bool = True,
):
"""
Run planning and preprocessing
Args:
splitted_4d_output_dir: base dir of splitted data
cropped_output_dir: base dir of cropped data
preprocessed_output_dir: base dir of preprocessed data
planner_name: planner name
dim: number of spatial dimensions
model_name: name of model to run planning for
model_cfg: hyperparameters of model (used during planning to
instantiate model)
num_processes: number of processes to use for preprocessing
run_preprocessing: Preprocess and check data. Defaults to True.
"""
planner_cls = PLANNER_REGISTRY.get(planner_name)
planner = planner_cls(
preprocessed_output_dir=preprocessed_output_dir
)
plan_identifiers = planner.plan_experiment(
model_name=model_name,
model_cfg=model_cfg,
)
if run_preprocessing:
for plan_id in plan_identifiers:
plan = load_pickle(preprocessed_output_dir / plan_id)
planner.run_preprocessing(
cropped_data_dir=cropped_output_dir / "imagesTr",
plan=plan,
num_processes=num_processes,
)
case_ids_failed, result_check = run_check(
data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
remove=True,
processes=num_processes
)
# delete and rerun corrupted cases
if not result_check:
logger.warning(f"{plan_id} check failed: There are corrupted files {case_ids_failed}!!!!"
f"Running preprocessing of those cases without multiprocessing.")
planner.run_preprocessing(
cropped_data_dir=cropped_output_dir / "imagesTr",
plan=plan,
num_processes=0,
)
case_ids_failed, result_check = run_check(
data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
remove=False,
processes=0
)
if not result_check:
logger.error(f"Could not fix corrupted files {case_ids_failed}!")
raise RuntimeError("Found corrupted files, check logs!")
else:
logger.info("Fixed corrupted files.")
else:
logger.info(f"{plan_id} check successful: Loading check completed")
if run_preprocessing:
create_labels(
preprocessed_output_dir=preprocessed_output_dir,
source_dir=splitted_4d_output_dir,
num_processes=num_processes,
)
def run_check(data_dir: Path,
remove: bool = False,
processes: int = 8,
keys: Sequence[str] = ("data", "seg"),
) -> Tuple[List[str], bool]:
"""
Check whether the files in a preprocessed dir are loadable
Args:
data_dir (Path): path to preprocessed data
remove (bool, optional): if loading fails, the npz and pkl files
are removed automatically. Defaults to False.
processes (int, optional): number of processes to use. If
0 processes are specified, a plain for loop is used. Defaults to 8.
keys: keys to load and check
Returns:
List[str]: case ids that failed the check
bool: True if all cases were loadable, False otherwise
"""
cases_npz = list(data_dir.glob("*.npz"))
cases_npz.sort()
cases_pkl = [case.parent / f"{(case.name).rsplit('.', 1)[0]}.pkl"
for case in cases_npz]
if processes == 0:
result = [check_case(case_npz, case_pkl, remove=remove)
for case_npz, case_pkl in zip(cases_npz, cases_pkl)]
else:
with Pool(processes=processes) as p:
result = p.starmap(check_case,
zip(cases_npz, cases_pkl, repeat(remove), repeat(keys)))
failed_cases = [fc[0] for fc in result if not fc[1]]
logger.info(f"Checked {len(result)} cases in {data_dir}")
return failed_cases, len(failed_cases) == 0
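# Illustrative usage of run_check (the path is hypothetical):
#
#   failed, ok = run_check(Path("preprocessed/D3V001_3d/imagesTr"), remove=False)
#   if not ok:
#       logger.error(f"Corrupted cases: {failed}")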
def check_case(case_npz: Path,
case_pkl: Path = None,
remove: bool = False,
keys: Sequence[str] = ("data", "seg"),
) -> Tuple[str, bool]:
"""
Check whether a single case is loadable
Args:
case_npz (Path): path to npz file
case_pkl (Path, optional): path to pkl file. Defaults to None.
remove (bool, optional): if loading fails, the npz and pkl files
are removed automatically. Defaults to False.
keys: keys to load and check
Returns:
str: case id
bool: true if case was loaded correctly, false otherwise
"""
logger.info(f"Checking {case_npz}")
case_id = get_case_id_from_path(case_npz, remove_modality=False)
try:
case_dict = load_npz_looped(str(case_npz), keys=keys, num_tries=3)
if "seg" in keys and case_pkl is not None:
properties = load_pickle(case_pkl)
seg = case_dict["seg"]
seg_instances = np.unique(seg) # automatically sorted
seg_instances = seg_instances[seg_instances > 0]
instances_properties = properties["instances"].keys()
props_instances = np.sort(np.array(list(map(int, instances_properties))))
if (len(seg_instances) != len(props_instances)) or any(seg_instances != props_instances):
logger.warning(f"Inconsistent instances {case_npz} from "
f"properties {props_instances} from seg {seg_instances}. "
f"Very small instances can get lost in resampling "
f"but larger instances should not disappear!")
for i in seg_instances:
if str(i) not in instances_properties:
raise RuntimeError(f"Found instance {seg_instances} in segmentation "
f"which is not in properties {instances_properties}."
f"Delete labels manually and rerun prepare label!")
except Exception as e:
logger.error(f"Failed to load {case_npz} with {e}")
logger.error(f"{traceback.format_exc()}")
if remove:
os.remove(case_npz)
if case_pkl is not None:
os.remove(case_pkl)
return case_id, False
return case_id, True
def run(cfg,
num_processes: int,
num_processes_preprocessing: int,
):
"""
Python interface for the preprocessing script
Args:
cfg: dict with config
num_processes: number of processes to use for cropping
num_processes_preprocessing: number of processes to use for resampling
"""
logger.remove()
logger.add(sys.stdout, level="INFO")
logger.add(Path(cfg["host"]["data_dir"]) / "logging.log", level="DEBUG")
data_info = cfg["data"]
if cfg["prep"]["crop"]:
# crop data to nonzero area
run_cropping_and_convert(cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
data_info=data_info,
overwrite=cfg["prep"]["overwrite"],
num_processes=num_processes,
)
if cfg["prep"]["analyze"]:
# compute statistics over data and segmentation (e.g. physical volume of individual classes)
run_dataset_analysis(cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
data_info=data_info,
num_processes=num_processes,
intensity_properties=True,
overwrite=cfg["prep"]["overwrite"],
)
if cfg["prep"]["plan"] or cfg["prep"]["process"]:
# plan future training
run_planning_and_process(
splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
planner_name=cfg["planner"],
dim=data_info["dim"],
model_name=cfg["module"],
model_cfg=cfg["model_cfg"],
num_processes=num_processes_preprocessing,
run_preprocessing=cfg["prep"]["process"],
)
@env_guard
def main():
parser = argparse.ArgumentParser()
parser.add_argument('tasks', type=str, nargs='+',
help="Single or multiple task identifiers to process consecutively",
)
parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file", default=[],
required=False)
parser.add_argument('--full_check',
help="Run a full check of the data.",
action='store_true',
)
parser.add_argument('--no_check',
help="Skip basic check.",
action='store_true',
)
parser.add_argument('-np', '--num_processes',
type=int, default=4, required=False,
help="Number of processes to use for croppping.",
)
parser.add_argument('-npp', '--num_processes_preprocessing',
type=int, default=3, required=False,
help="Number of processes to use for resampling.",
)
args = parser.parse_args()
tasks = args.tasks
ov = args.overwrites
full_check = args.full_check
no_check = args.no_check
num_processes = args.num_processes
num_processes_preprocessing = args.num_processes_preprocessing
initialize_config_module(config_module="nndet.conf")
# perform preprocessing checks first
if not no_check:
for task in tasks:
_ov = copy.deepcopy(ov) if ov is not None else []
cfg = compose(task, "config.yaml", overrides=_ov)
check_dataset_file(cfg["task"])
check_data_and_label_splitted(
cfg["task"],
test=False,
labels=True,
full_check=full_check,
)
if cfg["data"]["test_labels"]:
check_data_and_label_splitted(
cfg["task"],
test=True,
labels=True,
full_check=full_check,
)
# start preprocessing
for task in tasks:
_ov = copy.deepcopy(ov) if ov is not None else []
cfg = compose(task, "config.yaml", overrides=_ov)
run(OmegaConf.to_container(cfg, resolve=True),
num_processes=num_processes,
num_processes_preprocessing=num_processes_preprocessing,
)
if __name__ == '__main__':
main()
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import sys
import socket
import argparse
from pathlib import Path
from datetime import datetime
from typing import List
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from loguru import logger
from hydra.experimental import initialize_config_module
from omegaconf.omegaconf import OmegaConf
import nndet
from nndet.utils.config import compose, load_dataset_info
from nndet.utils.info import log_git, write_requirements_to_file, \
create_debug_plan, flatten_mapping
from nndet.utils.check import env_guard
from nndet.utils.analysis import run_analysis_suite
from nndet.io.datamodule.bg_module import Datamodule
from nndet.io.paths import get_task, get_training_dir
from nndet.io.load import load_pickle, save_json, save_pickle
from nndet.evaluator.registry import save_metric_output, evaluate_box_dir, \
evaluate_case_dir, evaluate_seg_dir
from nndet.inference.ensembler.base import extract_results
from nndet.ptmodule import MODULE_REGISTRY
@env_guard
def train():
"""
Training entry
"""
parser = argparse.ArgumentParser()
parser.add_argument('task', type=str,
help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file",
required=False)
parser.add_argument('--sweep',
help="Run empirical parameter optimization",
action='store_true',
)
args = parser.parse_args()
task = args.task
ov = args.overwrites
do_sweep = args.sweep
_train(
task=task,
ov=ov,
do_sweep=do_sweep,
)
@env_guard
def sweep():
"""
Sweep entry
"""
parser = argparse.ArgumentParser()
parser.add_argument('task', type=str,
help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
parser.add_argument('model', type=str,
help="full name of experiment to sweep e.g. RetinaUNetV0_D3V001_3d")
parser.add_argument('fold', type=int,
help="experiment fold")
args = parser.parse_args()
task = args.task
model = args.model
fold = args.fold
_sweep(
task=task,
model=model,
fold=fold,
)
@env_guard
def evaluate():
"""
Evaluation entry
seg, instances are not supported yet
"""
parser = argparse.ArgumentParser()
parser.add_argument('task', type=str, help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
parser.add_argument('model', type=str, help="model name, e.g. RetinaUNetV0_D3V001_3d")
parser.add_argument('fold', type=int, help="fold, -1 => consolidated")
parser.add_argument('--test',
help="Evaluate test predictions -> uses different folder",
action='store_true')
parser.add_argument('--case', help="Run Case Evaluation", action='store_true')
parser.add_argument('--boxes', help="Run Box Evaluation", action='store_true')
parser.add_argument('--seg', help="Run Box Evaluation", action='store_true')
parser.add_argument('--instances', help="Run Box Evaluation", action='store_true')
parser.add_argument('--analyze_boxes', help="Run Box Evaluation", action='store_true')
args = parser.parse_args()
model = args.model
fold = args.fold
task = args.task
test = args.test
do_boxes_eval = args.boxes
do_case_eval = args.case
do_seg_eval = args.seg
do_instances_eval = args.instances
do_analyze_boxes = args.analyze_boxes
_evaluate(
task=task,
model=model,
fold=fold,
test=test,
do_boxes_eval=do_boxes_eval,
do_case_eval=do_case_eval,
do_seg_eval=do_seg_eval,
do_instances_eval=do_instances_eval,
do_analyze_boxes=do_analyze_boxes,
)
def init_train_dir(cfg) -> Path:
"""
Initialize training directory and make it the current working directory
"""
# determine folder for experiment
output_dir = Path(cfg.host.parent_results) / str(cfg.task) / str(cfg.exp.id) / f"fold{cfg.exp.fold}"
if cfg["train"]["mode"].lower() == "overwrite":
if output_dir.is_dir():
print(f"Found existing folder {output_dir}, this run will overwrite "
f"the results inside that folder")
output_dir.mkdir(parents=True, exist_ok=True)
else:
if not output_dir.is_dir():
raise ValueError(f"{output_dir} is not a valid training dir and thus can not be resumed")
os.chdir(str(output_dir))
return output_dir
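# Note (added comment): the training directory resolves to
# <parent_results>/<task>/<exp.id>/fold<exp.fold>, where parent_results usually
# points at $det_models (cf. _sweep below), e.g. a hypothetical
# .../Task000D3_Example/RetinaUNetV001_D3V001_3d/fold0; it also becomes the
# current working directory of the run.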
def _train(
task: str,
ov: List[str],
do_sweep: bool,
):
"""
Run training
Args:
task: task to run training for
ov: overwrites for config manager
do_sweep: determine the best empirical parameters for the run
"""
print(f"Overwrites: {ov}")
initialize_config_module(config_module="nndet.conf")
cfg = compose(task, "config.yaml", overrides=ov if ov is not None else [])
assert cfg.host.parent_data is not None, 'Parent data can not be None'
assert cfg.host.parent_results is not None, 'Output dir can not be None'
train_dir = init_train_dir(cfg)
pl_logger = MLFlowLogger(
experiment_name=cfg["task"],
tags={
"host": socket.gethostname(),
"fold": cfg["exp"]["fold"],
"task": cfg["task"],
"job_id": os.getenv('LSB_JOBID', 'no_id'),
"mlflow.runName": cfg["exp"]["id"],
},
save_dir=os.getenv("MLFLOW_TRACKING_URI", "./mlruns"),
)
pl_logger.log_hyperparams(flatten_mapping(
{"model": OmegaConf.to_container(cfg["model_cfg"], resolve=True)}))
pl_logger.log_hyperparams(flatten_mapping(
{"trainer": OmegaConf.to_container(cfg["trainer_cfg"], resolve=True)}))
logger.remove()
logger.add(sys.stdout, format="{level} {message}", level="INFO")
log_file = Path(os.getcwd()) / "train.log"
logger.add(log_file, level="INFO")
logger.info(f"Log file at {log_file}")
meta_data = {}
meta_data["torch_version"] = str(torch.__version__)
meta_data["date"] = str(datetime.now())
meta_data["git"] = log_git(nndet.__path__[0], repo_name="nndet")
save_json(meta_data, "./meta.json")
try:
write_requirements_to_file("requirements.txt")
except Exception as e:
logger.error(f"Could not log req: {e}")
plan_path = Path(str(cfg.host["plan_path"]))
plan = load_pickle(plan_path)
save_json(create_debug_plan(plan), "./plan_debug.json")
data_dir = Path(cfg.host["preprocessed_output_dir"]) / plan["data_identifier"] / "imagesTr"
datamodule = Datamodule(
augment_cfg=OmegaConf.to_container(cfg["augment_cfg"], resolve=True),
plan=plan,
data_dir=data_dir,
fold=cfg["exp"]["fold"],
)
module = MODULE_REGISTRY[cfg["module"]](
model_cfg=OmegaConf.to_container(cfg["model_cfg"], resolve=True),
trainer_cfg=OmegaConf.to_container(cfg["trainer_cfg"], resolve=True),
plan=plan,
)
callbacks = []
checkpoint_cb = ModelCheckpoint(
dirpath=train_dir,
filename='model_best',
save_last=True,
save_top_k=1,
monitor=cfg["trainer_cfg"]["monitor_key"],
mode=cfg["trainer_cfg"]["monitor_mode"],
)
checkpoint_cb.CHECKPOINT_NAME_LAST = 'model_last'
callbacks.append(checkpoint_cb)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
OmegaConf.save(cfg, str(Path(os.getcwd()) / "config.yaml"))
OmegaConf.save(cfg, str(Path(os.getcwd()) / "config_resolved.yaml"), resolve=True)
save_pickle(plan, train_dir / "plan.pkl") # backup plan
splits = load_pickle(Path(cfg.host.preprocessed_output_dir) / datamodule.splits_file)
save_pickle(splits, train_dir / "splits.pkl")
trainer_kwargs = {}
if cfg["train"]["mode"].lower() == "resume":
trainer_kwargs["resume_from_checkpoint"] = train_dir / "model_last.ckpt"
num_gpus = cfg["trainer_cfg"]["gpus"]
logger.info(f"Using {num_gpus} GPUs for training")
plugins = cfg["trainer_cfg"].get("plugins", None)
logger.info(f"Using {plugins} plugins for training")
trainer = pl.Trainer(
gpus=list(range(num_gpus)) if num_gpus > 1 else num_gpus,
accelerator=cfg["trainer_cfg"]["accelerator"],
precision=cfg["trainer_cfg"]["precision"],
amp_backend=cfg["trainer_cfg"]["amp_backend"],
amp_level=cfg["trainer_cfg"]["amp_level"],
benchmark=cfg["trainer_cfg"]["benchmark"],
deterministic=cfg["trainer_cfg"]["deterministic"],
callbacks=callbacks,
logger=pl_logger,
max_epochs=module.max_epochs,
progress_bar_refresh_rate=None if bool(int(os.getenv("det_verbose", 1))) else 0,
reload_dataloaders_every_epoch=False,
num_sanity_val_steps=10,
weights_summary='full',
plugins=plugins,
terminate_on_nan=True, # TODO: make modular
move_metrics_to_cpu=True,
**trainer_kwargs
)
trainer.fit(module, datamodule=datamodule)
if do_sweep:
case_ids = splits[cfg["exp"]["fold"]]["val"]
if "debug" in cfg and "num_cases_val" in cfg["debug"]:
case_ids = case_ids[:cfg["debug"]["num_cases_val"]]
inference_plan = module.sweep(
cfg=OmegaConf.to_container(cfg, resolve=True),
save_dir=train_dir,
train_data_dir=data_dir,
case_ids=case_ids,
run_prediction=True,
)
plan["inference_plan"] = inference_plan
save_pickle(plan, train_dir / "plan_inference.pkl")
ensembler_cls = module.get_ensembler_cls(
key="boxes", dim=plan["network_dim"]) # TODO: make this configurable
for restore in [True, False]:
target_dir = train_dir / "val_predictions" if restore else \
train_dir / "val_predictions_preprocessed"
extract_results(source_dir=train_dir / "sweep_predictions",
target_dir=target_dir,
ensembler_cls=ensembler_cls,
restore=restore,
**inference_plan,
)
_evaluate(
task=cfg["task"],
model=cfg["exp"]["id"],
fold=cfg["exp"]["fold"],
test=False,
do_boxes_eval=True, # TODO: make this configurable
do_analyze_boxes=True, # TODO: make this configurable
)
def _sweep(
task: str,
model: str,
fold: int,
):
"""
Determine best postprocessing parameters for a trained model
Args:
task: current task
model: full name of the model run to determine empirical parameters for,
e.g. RetinaUNetV001_D3V001_3d
fold: current fold
"""
nndet_data_dir = Path(os.getenv("det_models"))
task = get_task(task, name=True, models=True)
train_dir = nndet_data_dir / task / model / f"fold{fold}"
cfg = OmegaConf.load(str(train_dir / "config.yaml"))
os.chdir(str(train_dir))
logger.remove()
logger.add(sys.stdout, format="{level} {message}", level="INFO")
log_file = Path(os.getcwd()) / "sweep.log"
logger.add(log_file, level="INFO")
logger.info(f"Log file at {log_file}")
plan = load_pickle(train_dir / "plan.pkl")
data_dir = Path(cfg.host["preprocessed_output_dir"]) / plan["data_identifier"] / "imagesTr"
module = MODULE_REGISTRY[cfg["module"]](
model_cfg=OmegaConf.to_container(cfg["model_cfg"], resolve=True),
trainer_cfg=OmegaConf.to_container(cfg["trainer_cfg"], resolve=True),
plan=plan,
)
splits = load_pickle(train_dir / "splits.pkl")
case_ids = splits[cfg["exp"]["fold"]]["val"]
inference_plan = module.sweep(
cfg=OmegaConf.to_container(cfg, resolve=True),
save_dir=train_dir,
train_data_dir=data_dir,
case_ids=case_ids,
        run_prediction=True, # TODO: add command line arg
)
plan["inference_plan"] = inference_plan
save_pickle(plan, train_dir / "plan_inference.pkl")
ensembler_cls = module.get_ensembler_cls(
key="boxes", dim=plan["network_dim"]) # TODO: make this configurable
for restore in [True, False]:
target_dir = train_dir / "val_predictions" if restore else \
train_dir / "val_predictions_preprocessed"
extract_results(source_dir=train_dir / "sweep_predictions",
target_dir=target_dir,
ensembler_cls=ensembler_cls,
restore=restore,
**inference_plan,
)
_evaluate(
task=cfg["task"],
model=cfg["exp"]["id"],
fold=cfg["exp"]["fold"],
test=False,
do_boxes_eval=True, # TODO: make this configurable
do_analyze_boxes=True, # TODO: make this configurable
)
def _evaluate(
task: str,
model: str,
fold: int,
test: bool = False,
do_case_eval: bool = False,
do_boxes_eval: bool = False,
do_seg_eval: bool = False,
do_instances_eval: bool = False,
do_analyze_boxes: bool = False,
):
"""
This entrypoint runs the evaluation
Args:
task: current task
        model: full name of the model run to evaluate,
e.g. RetinaUNetV001_D3V001_3d
fold: current fold
test: use test split
do_case_eval: evaluate patient metrics
do_boxes_eval: perform box evaluation
do_seg_eval: perform semantic segmentation evaluation
do_instances_eval: perform instance segmentation evaluation
do_analyze_boxes: run analysis of box results
"""
# prepare paths
task = get_task(task, name=True)
model_dir = Path(os.getenv("det_models")) / task / model
training_dir = get_training_dir(model_dir, fold)
data_dir_task = Path(os.getenv("det_data")) / task
data_cfg = load_dataset_info(data_dir_task)
prefix = "test" if test else "val"
modes = [True] if test else [True, False]
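    # the test split is only evaluated on restored predictions; for the validation split both
    # the restored and the preprocessed predictions are scored against the matching ground truth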
for restore in modes:
if restore:
pred_dir_name = f"{prefix}_predictions"
gt_dir_name = "labelsTs" if test else "labelsTr"
gt_dir = data_dir_task / "preprocessed" / gt_dir_name
else:
plan = load_pickle(training_dir / "plan.pkl")
pred_dir_name = f"{prefix}_predictions_preprocessed"
gt_dir = data_dir_task / "preprocessed" / plan["data_identifier"] / "labelsTr"
pred_dir = training_dir / pred_dir_name
save_dir = training_dir / f"{prefix}_results" if restore else \
training_dir / f"{prefix}_results_preprocessed"
# compute metrics
if do_boxes_eval:
logger.info(f"Computing box metrics: restore {restore}")
scores, curves = evaluate_box_dir(
pred_dir=pred_dir,
gt_dir=gt_dir,
classes=list(data_cfg["labels"].keys()),
save_dir=save_dir / "boxes",
)
save_metric_output(scores, curves, save_dir, "results_boxes")
if do_case_eval:
logger.info(f"Computing case metrics: restore {restore}")
scores, curves = evaluate_case_dir(
pred_dir=pred_dir,
gt_dir=gt_dir,
classes=list(data_cfg["labels"].keys()),
target_class=data_cfg["target_class"],
)
save_metric_output(scores, curves, save_dir, "results_case")
if do_seg_eval:
logger.info(f"Computing seg metrics: restore {restore}")
scores, curves = evaluate_seg_dir(
pred_dir=pred_dir,
gt_dir=gt_dir,
)
save_metric_output(scores, curves, save_dir, "results_seg")
if do_instances_eval:
raise NotImplementedError
# run analysis
save_dir = training_dir / f"{prefix}_analysis" if restore else \
training_dir / f"{prefix}_analysis_preprocessed"
if do_analyze_boxes:
logger.info(f"Analyze box predictions: restore {restore}")
run_analysis_suite(prediction_dir=pred_dir,
gt_dir=gt_dir,
save_dir=save_dir / "boxes",
)
if __name__ == "__main__":
train()
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
def boxes2nii():
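    """
    Export box predictions (<case_id>_boxes.pkl) of a trained model as NIfTI
    instance masks and save score, label and box of each instance as json.
    Only predictions with a score at or above the given threshold are exported.
    """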
import os
import argparse
from pathlib import Path
import numpy as np
import SimpleITK as sitk
from loguru import logger
from nndet.io import save_json, load_pickle
from nndet.io.paths import get_task, get_training_dir
from nndet.utils.info import maybe_verbose_iterable
parser = argparse.ArgumentParser()
parser.add_argument('task', type=str, help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
parser.add_argument('model', type=str, help="model name, e.g. RetinaUNetV0")
    parser.add_argument('-f', '--fold', type=int, help="fold of the trained model to export", default=0, required=False)
parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file",
required=False)
parser.add_argument('--threshold',
type=float,
help="Minimum probability of predictions",
required=False,
default=0.5,
)
parser.add_argument('--test', action='store_true')
args = parser.parse_args()
model = args.model
fold = args.fold
task = args.task
overwrites = args.overwrites
test = args.test
threshold = args.threshold
task_name = get_task(task, name=True, models=True)
task_dir = Path(os.getenv("det_models")) / task_name
training_dir = get_training_dir(task_dir / model, fold)
overwrites = overwrites if overwrites is not None else []
overwrites.append("host.parent_data=${env:det_data}")
overwrites.append("host.parent_results=${env:det_models}")
prediction_dir = training_dir / "test_predictions" \
if test else training_dir / "val_predictions"
save_dir = training_dir / "test_predictions_nii" \
if test else training_dir / "val_predictions_nii"
save_dir.mkdir(exist_ok=True)
case_ids = [p.stem.rsplit('_', 1)[0] for p in prediction_dir.glob("*_boxes.pkl")]
for cid in maybe_verbose_iterable(case_ids):
res = load_pickle(prediction_dir / f"{cid}_boxes.pkl")
instance_mask = np.zeros(res["original_size_of_raw_data"], dtype=np.uint8)
boxes = res["pred_boxes"]
scores = res["pred_scores"]
labels = res["pred_labels"]
_mask = scores >= threshold
boxes = boxes[_mask]
labels = labels[_mask]
scores = scores[_mask]
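        # sort ascending by score so that higher scoring boxes are drawn last and
        # overwrite lower scoring ones where boxes overlap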
idx = np.argsort(scores)
scores = scores[idx]
boxes = boxes[idx]
labels = labels[idx]
prediction_meta = {}
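        # boxes are stored as (x1, y1, x2, y2[, z1, z2]); build one slice per spatial axis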
for instance_id, (pbox, pscore, plabel) in enumerate(zip(boxes, scores, labels), start=1):
mask_slicing = [slice(int(pbox[0]), int(pbox[2])),
slice(int(pbox[1]), int(pbox[3])),
]
if instance_mask.ndim == 3:
mask_slicing.append(slice(int(pbox[4]), int(pbox[5])))
instance_mask[tuple(mask_slicing)] = instance_id
prediction_meta[int(instance_id)] = {
"score": float(pscore),
"label": int(plabel),
"box": list(map(int, pbox))
}
logger.info(f"Created instance mask with {instance_mask.max()} instances.")
instance_mask_itk = sitk.GetImageFromArray(instance_mask)
instance_mask_itk.SetOrigin(res["itk_origin"])
instance_mask_itk.SetDirection(res["itk_direction"])
instance_mask_itk.SetSpacing(res["itk_spacing"])
sitk.WriteImage(instance_mask_itk, str(save_dir / f"{cid}_boxes.nii.gz"))
save_json(prediction_meta, save_dir / f"{cid}_boxes.json")
def seg2nii():
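    """
    Export semantic segmentation predictions (<case_id>_seg.pkl) of a trained
    model as NIfTI files with the original ITK meta data.
    """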
import os
import argparse
from pathlib import Path
import SimpleITK as sitk
from nndet.io import load_pickle
from nndet.io.paths import get_task, get_training_dir
from nndet.utils.info import maybe_verbose_iterable
parser = argparse.ArgumentParser()
parser.add_argument('task', type=str, help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
parser.add_argument('model', type=str, help="model name, e.g. RetinaUNetV0")
    parser.add_argument('-f', '--fold', type=int, help="fold of the trained model to export", default=0, required=False)
parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file",
required=False)
parser.add_argument('--test', action='store_true')
args = parser.parse_args()
model = args.model
fold = args.fold
task = args.task
overwrites = args.overwrites
test = args.test
task_name = get_task(task, name=True, models=True)
task_dir = Path(os.getenv("det_models")) / task_name
training_dir = get_training_dir(task_dir / model, fold)
overwrites = overwrites if overwrites is not None else []
overwrites.append("host.parent_data=${env:det_data}")
overwrites.append("host.parent_results=${env:det_models}")
prediction_dir = training_dir / "test_predictions" \
if test else training_dir / "val_predictions"
save_dir = training_dir / "test_predictions_nii" \
if test else training_dir / "val_predictions_nii"
save_dir.mkdir(exist_ok=True)
case_ids = [p.stem.rsplit('_', 1)[0] for p in prediction_dir.glob("*_seg.pkl")]
for cid in maybe_verbose_iterable(case_ids):
res = load_pickle(prediction_dir / f"{cid}_seg.pkl")
seg_itk = sitk.GetImageFromArray(res["pred_seg"])
seg_itk.SetOrigin(res["itk_origin"])
seg_itk.SetDirection(res["itk_direction"])
seg_itk.SetSpacing(res["itk_spacing"])
sitk.WriteImage(seg_itk, str(save_dir / f"{cid}_seg.nii.gz"))
def unpack():
import argparse
from pathlib import Path
from nndet.io.load import unpack_dataset
parser = argparse.ArgumentParser()
parser.add_argument('path', type=Path, help="Path to folder to unpack")
parser.add_argument('num_processes', type=int, help="number of processes to use for unpacking")
args = parser.parse_args()
p = args.path
num_processes = args.num_processes
unpack_dataset(p, num_processes, False)
def env():
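    """Print PyTorch, CUDA and relevant environment information for debugging."""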
import os
import torch
import sys
print(f"PyTorch Version: {torch.version}")
print(f"PyTorch CUDA: {torch.version.cuda}")
print(f"PyTorch Backend cudnn: {torch.backends.cudnn.version()}")
print(f"PyTorch CUDA Arch List: {torch.cuda.get_arch_list()}")
print(f"PyTorch Current Device Capability: {torch.cuda.get_device_capability()}")
print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
stream = os.popen('nvcc --version')
output = stream.read()
print(f"System NVCC: {output}")
print(f"System Arch List: {os.getenv('TORCH_CUDA_ARCH_LIST', None)}")
print(f"System OMP_NUM_THREADS: {os.getenv('OMP_NUM_THREADS', None)}")
print(f"System CUDA_HOME is None: {os.getenv('CUDA_HOME', None) is None}")
print(f"Python Version: {sys.version}")
if __name__ == '__main__':
env()
[pycodestyle]
exclude = .eggs,*.egg,build,docs/*,.git,*/conf.py
ignore = E402, E721
max_line_length = 120
[coverage:run]
source = nndet
omit =
*__init__.py
*registry.py
tests/*
from setuptools import setup, find_packages
from pathlib import Path
import os
import sys
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
def resolve_requirements(file):
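    """Read a pip requirements file and recursively follow "-r <file>" includes."""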
requirements = []
with open(file) as f:
req = f.read().splitlines()
for r in req:
if r.startswith("-r"):
requirements += resolve_requirements(
os.path.join(os.path.dirname(file), r.split(" ")[1]))
else:
requirements.append(r)
return requirements
def read_file(file):
with open(file) as f:
content = f.read()
return content
def clean():
"""Custom clean command to tidy up the project root."""
os.system('rm -vrf ./build ./dist ./*.pyc ./*.tgz')
def get_extensions():
"""
Adapted from https://github.com/pytorch/vision/blob/master/setup.py
and https://github.com/facebookresearch/detectron2/blob/master/setup.py
"""
print("Build csrc")
print("Building with {}".format(sys.version_info))
this_dir = Path(os.path.dirname(os.path.abspath(__file__)))
extensions_dir = this_dir/'nndet'/'csrc'
main_file = list(extensions_dir.glob('*.cpp'))
    source_cpu = []  # list((extensions_dir/'cpu').glob('*.cpp')); disabled until header files are added
source_cuda = list((extensions_dir/'cuda').glob('*.cu'))
print("main_file {}".format(main_file))
print("source_cpu {}".format(source_cpu))
print("source_cuda {}".format(source_cuda))
sources = main_file + source_cpu
extension = CppExtension
define_macros = []
extra_compile_args = {"cxx": []}
if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
print("Adding CUDA csrc to build")
print("CUDA ARCH {}".format(os.getenv("TORCH_CUDA_ARCH_LIST")))
extension = CUDAExtension
sources += source_cuda
define_macros += [('WITH_CUDA', None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
        # ideally PyTorch would handle this by default
CC = os.environ.get("CC", None)
if CC is not None:
extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [str(extensions_dir)]
ext_modules = [
extension(
'nndet._C',
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
requirements = resolve_requirements(os.path.join(os.path.dirname(__file__),
'requirements.txt'))
readme = read_file(os.path.join(os.path.dirname(__file__), "README.md"))
setup(
name='nndet',
version="v0.1",
packages=find_packages(),
include_package_data=True,
test_suite="unittest",
long_description=readme,
long_description_content_type='text/markdown',
install_requires=requirements,
tests_require=["coverage"],
python_requires=">=3.8",
author="Division of Medical Image Computing, German Cancer Research Center",
maintainer_email='m.baumgartner@dkfz-heidelberg.de',
ext_modules=get_extensions(),
cmdclass={
'build_ext': BuildExtension,
'clean': clean,
},
entry_points={
'console_scripts': [
'nndet_example = scripts.generate_example:main',
'nndet_prep = scripts.preprocess:main',
'nndet_cls2fg = scripts.convert_cls2fg:main',
'nndet_seg2det = scripts.convert_seg2det:main',
'nndet_train = scripts.train:train',
'nndet_sweep = scripts.train:sweep',
'nndet_eval = scripts.train:evaluate',
'nndet_predict = scripts.predict:main',
'nndet_consolidate = scripts.consolidate:main',
'nndet_boxes2nii = scripts.utils:boxes2nii',
'nndet_seg2nii = scripts.utils:seg2nii',
'nndet_unpack = scripts.utils:unpack',
'nndet_env = scripts.utils:env',
]
},
)