Commit 724ba674 authored by mibaumgartner's avatar mibaumgartner
Browse files

minor improvements in predict script

parent 0f3a95f3
......@@ -36,6 +36,7 @@ def run(cfg: dict,
process: bool = True,
num_models: int = None,
num_tta_transforms: int = None,
test_split: bool = False,
):
"""
Run inference pipeline
......@@ -48,6 +49,10 @@ def run(cfg: dict,
are used
num_tta_transforms: number of tta transformation; if None the maximum
number of transformation is used
test_split: Typical usage of nnDetection will never require
this option! Predict an already preprocessed split of the original
training data. The 'test' split needs to be located in fold 0
of a manually created split file.
"""
plan = load_pickle(training_dir / "plan_inference.pkl")
......@@ -68,7 +73,13 @@ def run(cfg: dict,
)
prediction_dir.mkdir(parents=True, exist_ok=True)
if test_split:
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTr"
case_ids = load_pickle(training_dir / "splits.pkl")[0]["test"]
else:
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTs"
case_ids = None
predict_dir(source_dir=source_dir,
target_dir=prediction_dir,
cfg=cfg,
......@@ -78,7 +89,8 @@ def run(cfg: dict,
num_tta_transforms=num_tta_transforms,
model_fn=load_all_models,
restore=True,
# do_seg=True, # TODO: change this...
case_ids=case_ids,
**cfg.get("inference_kwargs", {}),
)
......@@ -129,7 +141,10 @@ def main():
help="number of tta transforms (per default most tta are chosen)",
required=False)
parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file", default=None,
help=("overwrites for config file. "
"inference_kwargs can be used to add additional "
"keyword arguments to inference."),
default=None,
required=False)
parser.add_argument('--no_preprocess', help="Preprocess test data", action='store_false')
parser.add_argument('--force_args',
......@@ -137,6 +152,12 @@ def main():
"and fold might differ from the original one. "
"This forces an overwrite to the passed in arguments of"
" this function. This can be dangerous!"), action='store_true')
parser.add_argument('--test_split',
help=("Typical usage of nnDetection will never require "
"this option! Predict an already preprocessed "
"split of the original training data. "
"The 'test' split needs to be located in fold 0 "
"of a manually created split file."))
args = parser.parse_args()
model = args.model
......@@ -146,12 +167,16 @@ def main():
num_tta_transforms = args.num_tta
ov = args.overwrites
force_args = args.force_args
test_split = args.test_split
task_name = get_task(task, name=True)
task_model_dir = Path(os.getenv("det_models"))
training_dir = get_training_dir(task_model_dir / task_name / model, fold)
process = args.no_preprocess
if test_split:
raise ValueError("When using the test split option raw data is not "
"supported. Need to add --no_preprocess flag!")
cfg = OmegaConf.load(str(training_dir / "config.yaml"))
......@@ -170,6 +195,7 @@ def main():
process=process,
num_models=num_models,
num_tta_transforms=num_tta_transforms,
test_split=test_split,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment