Commit 2a8e54b4 authored by mibaumgartner's avatar mibaumgartner
Browse files

add checkpoint identifier for ensembling

parent b32e7c19
......@@ -32,17 +32,18 @@ from nndet.inference.ensembler.base import extract_results
from nndet.io import get_task, load_pickle, save_pickle
def consolidate_models(source_dirs: Sequence[Path], target_dir: Path):
def consolidate_models(source_dirs: Sequence[Path], target_dir: Path, ckpt: str):
"""
Copy final models from folds into consolidated folder
Args:
source_dirs: directory of each fold to consolidate
target_dir: directory to save models to
ckpt: checkpoint identifier to select models for ensembling
"""
for fold, sd in enumerate(source_dirs):
model_paths = list(sd.glob('*.ckpt'))
found_models = [mp for mp in model_paths if "last" in str(mp.stem)]
found_models = [mp for mp in model_paths if ckpt in str(mp.stem)]
assert len(found_models) == 1, f"Found wrong number of models, {found_models}"
model_path = found_models[0]
assert f"fold{fold}" in str(model_path.parent.stem), f"Expected fold {fold} but found {model_path}"
......@@ -108,6 +109,9 @@ def main():
parser.add_argument('--sweep_instances', action="store_true",
help="Sweep for best parameters for instance segmentation based models",
)
parser.add_argument('--ckpt', type=str, default="last", required=False,
help="Define identifier of checkpoint for consolidation. "
"Use this with care!")
args = parser.parse_args()
model = args.model
......@@ -120,6 +124,7 @@ def main():
sweep_boxes = args.sweep_boxes
sweep_instances = args.sweep_instances
ckpt = args.ckpt
if consolidate == "export" and not (sweep_boxes or sweep_instances):
raise ValueError("Export needs new parameter sweep! Actiate one of the sweep "
......@@ -142,7 +147,10 @@ def main():
# model consolidation
if do_model_consolidation:
logger.info("Consolidate models")
consolidate_models(training_dirs, target_dir)
if ckpt != "last":
logger.warning(f"Found ckpt overwrite {ckpt}, this is not the default, "
"this can drastically influence the performance!")
consolidate_models(training_dirs, target_dir, ckpt)
# consolidate predictions
logger.info("Consolidate predictions")
......
......@@ -81,7 +81,7 @@ def run(cfg: dict,
else:
source_dir = preprocessed_output_dir / plan["data_identifier"] / "imagesTs"
case_ids = None
predict_dir(source_dir=source_dir,
target_dir=prediction_dir,
cfg=cfg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment