Commit 2a8e54b4 authored by mibaumgartner's avatar mibaumgartner
Browse files

add checkpoint identifier for ensembling

parent b32e7c19
...@@ -32,17 +32,18 @@ from nndet.inference.ensembler.base import extract_results ...@@ -32,17 +32,18 @@ from nndet.inference.ensembler.base import extract_results
from nndet.io import get_task, load_pickle, save_pickle from nndet.io import get_task, load_pickle, save_pickle
def consolidate_models(source_dirs: Sequence[Path], target_dir: Path): def consolidate_models(source_dirs: Sequence[Path], target_dir: Path, ckpt: str):
""" """
Copy final models from folds into consolidated folder Copy final models from folds into consolidated folder
Args: Args:
source_dirs: directory of each fold to consolidate source_dirs: directory of each fold to consolidate
target_dir: directory to save models to target_dir: directory to save models to
ckpt: checkpoint identifier to select models for ensembling
""" """
for fold, sd in enumerate(source_dirs): for fold, sd in enumerate(source_dirs):
model_paths = list(sd.glob('*.ckpt')) model_paths = list(sd.glob('*.ckpt'))
found_models = [mp for mp in model_paths if "last" in str(mp.stem)] found_models = [mp for mp in model_paths if ckpt in str(mp.stem)]
assert len(found_models) == 1, f"Found wrong number of models, {found_models}" assert len(found_models) == 1, f"Found wrong number of models, {found_models}"
model_path = found_models[0] model_path = found_models[0]
assert f"fold{fold}" in str(model_path.parent.stem), f"Expected fold {fold} but found {model_path}" assert f"fold{fold}" in str(model_path.parent.stem), f"Expected fold {fold} but found {model_path}"
...@@ -108,6 +109,9 @@ def main(): ...@@ -108,6 +109,9 @@ def main():
parser.add_argument('--sweep_instances', action="store_true", parser.add_argument('--sweep_instances', action="store_true",
help="Sweep for best parameters for instance segmentation based models", help="Sweep for best parameters for instance segmentation based models",
) )
parser.add_argument('--ckpt', type=str, default="last", required=False,
help="Define identifier of checkpoint for consolidation. "
"Use this with care!")
args = parser.parse_args() args = parser.parse_args()
model = args.model model = args.model
...@@ -120,6 +124,7 @@ def main(): ...@@ -120,6 +124,7 @@ def main():
sweep_boxes = args.sweep_boxes sweep_boxes = args.sweep_boxes
sweep_instances = args.sweep_instances sweep_instances = args.sweep_instances
ckpt = args.ckpt
if consolidate == "export" and not (sweep_boxes or sweep_instances): if consolidate == "export" and not (sweep_boxes or sweep_instances):
raise ValueError("Export needs new parameter sweep! Actiate one of the sweep " raise ValueError("Export needs new parameter sweep! Actiate one of the sweep "
...@@ -142,7 +147,10 @@ def main(): ...@@ -142,7 +147,10 @@ def main():
# model consolidation # model consolidation
if do_model_consolidation: if do_model_consolidation:
logger.info("Consolidate models") logger.info("Consolidate models")
consolidate_models(training_dirs, target_dir) if ckpt != "last":
logger.warning(f"Found ckpt overwrite {ckpt}, this is not the default, "
"this can drastically influence the performance!")
consolidate_models(training_dirs, target_dir, ckpt)
# consolidate predictions # consolidate predictions
logger.info("Consolidate predictions") logger.info("Consolidate predictions")
......
...@@ -81,7 +81,7 @@ def run(cfg: dict, ...@@ -81,7 +81,7 @@ def run(cfg: dict,
else: else:
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTs" source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTs"
case_ids = None case_ids = None
predict_dir(source_dir=source_dir, predict_dir(source_dir=source_dir,
target_dir=prediction_dir, target_dir=prediction_dir,
cfg=cfg, cfg=cfg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment