Commit 724ba674 authored by mibaumgartner's avatar mibaumgartner
Browse files

minor improvements in predict script

parent 0f3a95f3
...@@ -36,6 +36,7 @@ def run(cfg: dict, ...@@ -36,6 +36,7 @@ def run(cfg: dict,
process: bool = True, process: bool = True,
num_models: int = None, num_models: int = None,
num_tta_transforms: int = None, num_tta_transforms: int = None,
test_split: bool = False,
): ):
""" """
Run inference pipeline Run inference pipeline
...@@ -48,6 +49,10 @@ def run(cfg: dict, ...@@ -48,6 +49,10 @@ def run(cfg: dict,
are used are used
num_tta_transforms: number of tta transformation; if None the maximum num_tta_transforms: number of tta transformation; if None the maximum
number of transformation is used number of transformation is used
test_split: Typical usage of nnDetection will never require
this option! Predict an already preprocessed split of the original
training data. The 'test' split needs to be located in fold 0
of a manually created split file.
""" """
plan = load_pickle(training_dir / "plan_inference.pkl") plan = load_pickle(training_dir / "plan_inference.pkl")
...@@ -68,7 +73,13 @@ def run(cfg: dict, ...@@ -68,7 +73,13 @@ def run(cfg: dict,
) )
prediction_dir.mkdir(parents=True, exist_ok=True) prediction_dir.mkdir(parents=True, exist_ok=True)
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTs" if test_split:
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTr"
case_ids = load_pickle(training_dir / "splits.pkl")[0]["test"]
else:
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTs"
case_ids = None
predict_dir(source_dir=source_dir, predict_dir(source_dir=source_dir,
target_dir=prediction_dir, target_dir=prediction_dir,
cfg=cfg, cfg=cfg,
...@@ -78,7 +89,8 @@ def run(cfg: dict, ...@@ -78,7 +89,8 @@ def run(cfg: dict,
num_tta_transforms=num_tta_transforms, num_tta_transforms=num_tta_transforms,
model_fn=load_all_models, model_fn=load_all_models,
restore=True, restore=True,
# do_seg=True, # TODO: change this... case_ids=case_ids,
**cfg.get("inference_kwargs", {}),
) )
...@@ -129,7 +141,10 @@ def main(): ...@@ -129,7 +141,10 @@ def main():
help="number of tta transforms (per default most tta are chosen)", help="number of tta transforms (per default most tta are chosen)",
required=False) required=False)
parser.add_argument('-o', '--overwrites', type=str, nargs='+', parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file", default=None, help=("overwrites for config file. "
"inference_kwargs can be used to add additional "
"keyword arguments to inference."),
default=None,
required=False) required=False)
parser.add_argument('--no_preprocess', help="Preprocess test data", action='store_false') parser.add_argument('--no_preprocess', help="Preprocess test data", action='store_false')
parser.add_argument('--force_args', parser.add_argument('--force_args',
...@@ -137,6 +152,12 @@ def main(): ...@@ -137,6 +152,12 @@ def main():
"and fold might differ from the original one. " "and fold might differ from the original one. "
"This forces an overwrite to the passed in arguments of" "This forces an overwrite to the passed in arguments of"
" this function. This can be dangerous!"), action='store_true') " this function. This can be dangerous!"), action='store_true')
parser.add_argument('--test_split',
help=("Typical usage of nnDetection will never require "
"this option! Predict an already preprocessed "
"split of the original training data. "
"The 'test' split needs to be located in fold 0 "
"of a manually created split file."))
args = parser.parse_args() args = parser.parse_args()
model = args.model model = args.model
...@@ -146,12 +167,16 @@ def main(): ...@@ -146,12 +167,16 @@ def main():
num_tta_transforms = args.num_tta num_tta_transforms = args.num_tta
ov = args.overwrites ov = args.overwrites
force_args = args.force_args force_args = args.force_args
test_split = args.test_split
task_name = get_task(task, name=True) task_name = get_task(task, name=True)
task_model_dir = Path(os.getenv("det_models")) task_model_dir = Path(os.getenv("det_models"))
training_dir = get_training_dir(task_model_dir / task_name / model, fold) training_dir = get_training_dir(task_model_dir / task_name / model, fold)
process = args.no_preprocess process = args.no_preprocess
if test_split:
raise ValueError("When using the test split option raw data is not "
"supported. Need to add --no_preprocess flag!")
cfg = OmegaConf.load(str(training_dir / "config.yaml")) cfg = OmegaConf.load(str(training_dir / "config.yaml"))
...@@ -170,6 +195,7 @@ def main(): ...@@ -170,6 +195,7 @@ def main():
process=process, process=process,
num_models=num_models, num_models=num_models,
num_tta_transforms=num_tta_transforms, num_tta_transforms=num_tta_transforms,
test_split=test_split,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment