Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
219a52a8
Commit
219a52a8
authored
May 30, 2021
by
mibaumgartner
Browse files
option to consolidate and sweep different ckpt
parent
2a8e54b4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
13 additions
and
59 deletions
+13
-59
nndet/conf/train/v001.yaml
nndet/conf/train/v001.yaml
+6
-0
nndet/inference/helper.py
nndet/inference/helper.py
+2
-2
nndet/inference/loading.py
nndet/inference/loading.py
+3
-55
nndet/ptmodule/retinaunet/base.py
nndet/ptmodule/retinaunet/base.py
+2
-2
No files found.
nndet/conf/train/v001.yaml
View file @
219a52a8
...
@@ -21,6 +21,11 @@ augment_cfg:
...
@@ -21,6 +21,11 @@ augment_cfg:
num_cached_per_thread
:
2
num_cached_per_thread
:
2
multiprocessing
:
True
# only deactivate this if debugging
multiprocessing
:
True
# only deactivate this if debugging
# Additional overwrites
# patch_size; Default: plan
# batch_size; Default: plan
# splits; Default: splits_final
trainer_cfg
:
trainer_cfg
:
gpus
:
1
# number of gpus
gpus
:
1
# number of gpus
accelerator
:
ddp
# distributed backend
accelerator
:
ddp
# distributed backend
...
@@ -52,6 +57,7 @@ trainer_cfg:
...
@@ -52,6 +57,7 @@ trainer_cfg:
poly_gamma
:
0.9
poly_gamma
:
0.9
swa_epochs
:
10
# number of epochs to run swa with cyclic learning rate
swa_epochs
:
10
# number of epochs to run swa with cyclic learning rate
# sweep_ckpt: Select checkpoint identifier for sweeping. Default "last".
model_cfg
:
model_cfg
:
encoder_kwargs
:
{}
# keyword arguments passed to encoder
encoder_kwargs
:
{}
# keyword arguments passed to encoder
...
...
nndet/inference/helper.py
View file @
219a52a8
...
@@ -23,7 +23,7 @@ from loguru import logger
...
@@ -23,7 +23,7 @@ from loguru import logger
from
nndet.utils.tensor
import
to_numpy
from
nndet.utils.tensor
import
to_numpy
from
nndet.io.load
import
load_pickle
,
save_pickle
from
nndet.io.load
import
load_pickle
,
save_pickle
from
nndet.io.paths
import
Pathlike
,
get_case_id_from_path
from
nndet.io.paths
import
Pathlike
,
get_case_id_from_path
from
nndet.inference.loading
import
load_
time_ensemble
from
nndet.inference.loading
import
load_
final_model
def
predict_dir
(
def
predict_dir
(
...
@@ -32,7 +32,7 @@ def predict_dir(
...
@@ -32,7 +32,7 @@ def predict_dir(
cfg
:
dict
,
cfg
:
dict
,
plan
:
dict
,
plan
:
dict
,
source_models
:
Path
,
source_models
:
Path
,
model_fn
:
Callable
[[
Path
,
dict
,
dict
,
int
],
Sequence
[
dict
]]
=
load_
time_ensemble
,
model_fn
:
Callable
[[
Path
,
dict
,
dict
,
int
],
Sequence
[
dict
]]
=
load_
final_model
,
num_models
:
int
=
None
,
num_models
:
int
=
None
,
num_tta_transforms
:
int
=
None
,
num_tta_transforms
:
int
=
None
,
restore
:
bool
=
False
,
restore
:
bool
=
False
,
...
...
nndet/inference/loading.py
View file @
219a52a8
...
@@ -27,14 +27,10 @@ from nndet.io.paths import Pathlike
...
@@ -27,14 +27,10 @@ from nndet.io.paths import Pathlike
def
get_loader_fn
(
mode
:
str
,
**
kwargs
):
def
get_loader_fn
(
mode
:
str
,
**
kwargs
):
if
mode
==
"best"
:
if
mode
.
lower
()
==
"all"
:
load_fn
=
partial
(
load_time_ensemble
,
**
kwargs
)
load_fn
=
load_all_models
elif
mode
==
"final"
:
load_fn
=
partial
(
load_final_model
,
**
kwargs
)
elif
mode
==
"latest"
:
load_fn
=
partial
(
load_final_model
,
identifier
=
"latest"
,
**
kwargs
)
else
:
else
:
raise
ValueError
(
f
"Unknown mode
{
mode
}
"
)
load_fn
=
partial
(
load_final_model
,
identifier
=
mode
,
**
kwargs
)
return
load_fn
return
load_fn
...
@@ -61,54 +57,6 @@ def get_latest_model(base_dir: Pathlike, fold: int = 0) -> Optional[Path]:
...
@@ -61,54 +57,6 @@ def get_latest_model(base_dir: Pathlike, fold: int = 0) -> Optional[Path]:
return
None
return
None
# TODO: update
def
load_time_ensemble
(
source_models
:
Path
,
cfg
:
dict
,
plan
:
dict
,
num_models
:
int
=
None
,
)
->
Sequence
[
dict
]:
"""
Load time ensembled models
Args:
source_models: path to directory where models are saved
cfg: config used for experiment
`model`: name of model in DETECTION_REGISTRY
plan: plan used for training
num_models: number of models to load
Returns:
Sequence[dict]: loaded models
`model`: loaded model
`rank`: rank of model
"""
logger
.
info
(
"Loading time ensemble"
)
model_names
=
list
(
source_models
.
glob
(
'model_best*.ckpt'
))
if
not
model_names
:
raise
RuntimeError
(
f
"Did not find any models in
{
source_models
}
"
)
models
=
[]
for
path
in
model_names
:
model
=
MODULE_REGISTRY
[
cfg
[
"module"
]](
model_cfg
=
cfg
[
"model_cfg"
],
trainer_cfg
=
cfg
[
"trainer_cfg"
],
plan
=
plan
,
)
state_dict
=
torch
.
load
(
path
,
map_location
=
"cpu"
)[
"state_dict"
]
t
=
model
.
load_state_dict
(
state_dict
)
logger
.
info
(
f
"Loaded
{
path
}
with
{
t
}
"
)
model
.
float
()
model
.
eval
()
rank
=
int
(
str
(
path
).
rsplit
(
os
.
sep
,
1
)[
-
1
][
10
])
models
.
append
({
"model"
:
model
.
cpu
(),
"rank"
:
rank
})
if
num_models
is
not
None
:
models
=
models
[:
num_models
]
logger
.
info
(
f
"Using
{
len
(
models
)
}
models for for inference."
)
return
models
def
load_final_model
(
def
load_final_model
(
source_models
:
Path
,
source_models
:
Path
,
cfg
:
dict
,
cfg
:
dict
,
...
...
nndet/ptmodule/retinaunet/base.py
View file @
219a52a8
...
@@ -58,7 +58,7 @@ from nndet.training.learning_rate import LinearWarmupPolyLR
...
@@ -58,7 +58,7 @@ from nndet.training.learning_rate import LinearWarmupPolyLR
from
nndet.inference.predictor
import
Predictor
from
nndet.inference.predictor
import
Predictor
from
nndet.inference.sweeper
import
BoxSweeper
from
nndet.inference.sweeper
import
BoxSweeper
from
nndet.inference.transforms
import
get_tta_transforms
,
Inference2D
from
nndet.inference.transforms
import
get_tta_transforms
,
Inference2D
from
nndet.inference.loading
import
load_final_model
from
nndet.inference.loading
import
get_loader_fn
from
nndet.inference.helper
import
predict_dir
from
nndet.inference.helper
import
predict_dir
from
nndet.inference.ensembler.segmentation
import
SegmentationEnsembler
from
nndet.inference.ensembler.segmentation
import
SegmentationEnsembler
from
nndet.inference.ensembler.detection
import
BoxEnsemblerSelective
from
nndet.inference.ensembler.detection
import
BoxEnsemblerSelective
...
@@ -762,7 +762,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -762,7 +762,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
num_tta_transforms
=
None
,
num_tta_transforms
=
None
,
case_ids
=
case_ids
,
case_ids
=
case_ids
,
save_state
=
True
,
save_state
=
True
,
model_fn
=
load_final_model
,
model_fn
=
get_loader_fn
(
mode
=
self
.
trainer_cfg
.
get
(
"sweep_ckpt"
,
"last"
))
,
**
kwargs
,
**
kwargs
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment