Commit 219a52a8 authored by mibaumgartner's avatar mibaumgartner
Browse files

option to consolidate and sweep different checkpoints

parent 2a8e54b4
...@@ -21,6 +21,11 @@ augment_cfg: ...@@ -21,6 +21,11 @@ augment_cfg:
num_cached_per_thread: 2 num_cached_per_thread: 2
multiprocessing: True # only deactivate this if debugging multiprocessing: True # only deactivate this if debugging
# Additional overwrites
# patch_size; Default: plan
# batch_size; Default: plan
# splits; Default: splits_final
trainer_cfg: trainer_cfg:
gpus: 1 # number of gpus gpus: 1 # number of gpus
accelerator: ddp # distributed backend accelerator: ddp # distributed backend
...@@ -52,6 +57,7 @@ trainer_cfg: ...@@ -52,6 +57,7 @@ trainer_cfg:
poly_gamma: 0.9 poly_gamma: 0.9
swa_epochs: 10 # number of epochs to run swa with cyclic learning rate swa_epochs: 10 # number of epochs to run swa with cyclic learning rate
# sweep_ckpt: Select checkpoint identifier for sweeping. Default "last".
model_cfg: model_cfg:
encoder_kwargs: {} # keyword arguments passed to encoder encoder_kwargs: {} # keyword arguments passed to encoder
......
...@@ -23,7 +23,7 @@ from loguru import logger ...@@ -23,7 +23,7 @@ from loguru import logger
from nndet.utils.tensor import to_numpy from nndet.utils.tensor import to_numpy
from nndet.io.load import load_pickle, save_pickle from nndet.io.load import load_pickle, save_pickle
from nndet.io.paths import Pathlike, get_case_id_from_path from nndet.io.paths import Pathlike, get_case_id_from_path
from nndet.inference.loading import load_time_ensemble from nndet.inference.loading import load_final_model
def predict_dir( def predict_dir(
...@@ -32,7 +32,7 @@ def predict_dir( ...@@ -32,7 +32,7 @@ def predict_dir(
cfg: dict, cfg: dict,
plan: dict, plan: dict,
source_models: Path, source_models: Path,
model_fn: Callable[[Path, dict, dict, int], Sequence[dict]] = load_time_ensemble, model_fn: Callable[[Path, dict, dict, int], Sequence[dict]] = load_final_model,
num_models: int = None, num_models: int = None,
num_tta_transforms: int = None, num_tta_transforms: int = None,
restore: bool = False, restore: bool = False,
......
...@@ -27,14 +27,10 @@ from nndet.io.paths import Pathlike ...@@ -27,14 +27,10 @@ from nndet.io.paths import Pathlike
def get_loader_fn(mode: str, **kwargs):
    """
    Resolve a model-loading function for a given checkpoint identifier.

    Args:
        mode: checkpoint selection mode. ``"all"`` (case-insensitive) loads
            every available model via :func:`load_all_models`; any other
            value is passed through as the checkpoint ``identifier`` to
            :func:`load_final_model` (e.g. ``"last"``, ``"best"``).
        **kwargs: additional keyword arguments forwarded to the loader.

    Returns:
        Callable: loader with the signature expected by ``predict_dir``'s
        ``model_fn`` argument.
    """
    if mode.lower() == "all":
        # NOTE(review): kwargs are intentionally not bound here — presumably
        # load_all_models takes none; verify against its signature.
        load_fn = load_all_models
    else:
        load_fn = partial(load_final_model, identifier=mode, **kwargs)
    return load_fn
...@@ -61,54 +57,6 @@ def get_latest_model(base_dir: Pathlike, fold: int = 0) -> Optional[Path]: ...@@ -61,54 +57,6 @@ def get_latest_model(base_dir: Pathlike, fold: int = 0) -> Optional[Path]:
return None return None
# TODO: update
def load_time_ensemble(
        source_models: Path,
        cfg: dict,
        plan: dict,
        num_models: int = None,
        ) -> Sequence[dict]:
    """
    Load a time ensemble: every ``model_best*.ckpt`` checkpoint found in
    ``source_models``, each restored into a freshly built module.

    Args:
        source_models: path to directory where models are saved
        cfg: config used for experiment
            `module`: name of module in MODULE_REGISTRY
            `model_cfg`, `trainer_cfg`: forwarded to the module constructor
        plan: plan used for training
        num_models: if given, keep only the first ``num_models`` checkpoints
            (after sorting by path). Default: keep all.

    Returns:
        Sequence[dict]: loaded models
            `model`: loaded model (on CPU, fp32, eval mode)
            `rank`: rank of model parsed from the checkpoint filename

    Raises:
        RuntimeError: if no matching checkpoint is found.
    """
    logger.info("Loading time ensemble")
    # sort for a deterministic order; otherwise num_models truncation below
    # would keep an arbitrary subset (glob order is filesystem dependent)
    model_names = sorted(source_models.glob('model_best*.ckpt'))
    if not model_names:
        raise RuntimeError(f"Did not find any models in {source_models}")
    models = []
    for path in model_names:
        model = MODULE_REGISTRY[cfg["module"]](
            model_cfg=cfg["model_cfg"],
            trainer_cfg=cfg["trainer_cfg"],
            plan=plan,
        )
        state_dict = torch.load(path, map_location="cpu")["state_dict"]
        t = model.load_state_dict(state_dict)
        logger.info(f"Loaded {path} with {t}")
        model.float()
        model.eval()
        # char 10 of the filename is the char right after "model_best";
        # assumes a single-digit rank, e.g. "model_best3.ckpt" -> 3
        # NOTE(review): breaks for ranks >= 10 — TODO confirm naming scheme
        rank = int(str(path).rsplit(os.sep, 1)[-1][10])
        models.append({"model": model.cpu(), "rank": rank})
    if num_models is not None:
        models = models[:num_models]
    logger.info(f"Using {len(models)} models for inference.")
    return models
def load_final_model( def load_final_model(
source_models: Path, source_models: Path,
cfg: dict, cfg: dict,
......
...@@ -58,7 +58,7 @@ from nndet.training.learning_rate import LinearWarmupPolyLR ...@@ -58,7 +58,7 @@ from nndet.training.learning_rate import LinearWarmupPolyLR
from nndet.inference.predictor import Predictor from nndet.inference.predictor import Predictor
from nndet.inference.sweeper import BoxSweeper from nndet.inference.sweeper import BoxSweeper
from nndet.inference.transforms import get_tta_transforms, Inference2D from nndet.inference.transforms import get_tta_transforms, Inference2D
from nndet.inference.loading import load_final_model from nndet.inference.loading import get_loader_fn
from nndet.inference.helper import predict_dir from nndet.inference.helper import predict_dir
from nndet.inference.ensembler.segmentation import SegmentationEnsembler from nndet.inference.ensembler.segmentation import SegmentationEnsembler
from nndet.inference.ensembler.detection import BoxEnsemblerSelective from nndet.inference.ensembler.detection import BoxEnsemblerSelective
...@@ -762,7 +762,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -762,7 +762,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
num_tta_transforms=None, num_tta_transforms=None,
case_ids=case_ids, case_ids=case_ids,
save_state=True, save_state=True,
model_fn=load_final_model, model_fn=get_loader_fn(mode=self.trainer_cfg.get("sweep_ckpt", "last")),
**kwargs, **kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment