Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
2a8e54b4
Commit
2a8e54b4
authored
May 30, 2021
by
mibaumgartner
Browse files
add checkpoint identifier for ensembling
parent
b32e7c19
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
4 deletions
+12
-4
scripts/consolidate.py
scripts/consolidate.py
+11
-3
scripts/predict.py
scripts/predict.py
+1
-1
No files found.
scripts/consolidate.py
View file @
2a8e54b4
...
...
@@ -32,17 +32,18 @@ from nndet.inference.ensembler.base import extract_results
from
nndet.io
import
get_task
,
load_pickle
,
save_pickle
def
consolidate_models
(
source_dirs
:
Sequence
[
Path
],
target_dir
:
Path
):
def
consolidate_models
(
source_dirs
:
Sequence
[
Path
],
target_dir
:
Path
,
ckpt
:
str
):
"""
Copy final models from folds into consolidated folder
Args:
source_dirs: directory of each fold to consolidate
target_dir: directory to save models to
ckpt: checkpoint identifier to select models for ensembling
"""
for
fold
,
sd
in
enumerate
(
source_dirs
):
model_paths
=
list
(
sd
.
glob
(
'*.ckpt'
))
found_models
=
[
mp
for
mp
in
model_paths
if
"last"
in
str
(
mp
.
stem
)]
found_models
=
[
mp
for
mp
in
model_paths
if
ckpt
in
str
(
mp
.
stem
)]
assert
len
(
found_models
)
==
1
,
f
"Found wrong number of models,
{
found_models
}
"
model_path
=
found_models
[
0
]
assert
f
"fold
{
fold
}
"
in
str
(
model_path
.
parent
.
stem
),
f
"Expected fold
{
fold
}
but found
{
model_path
}
"
...
...
@@ -108,6 +109,9 @@ def main():
parser
.
add_argument
(
'--sweep_instances'
,
action
=
"store_true"
,
help
=
"Sweep for best parameters for instance segmentation based models"
,
)
parser
.
add_argument
(
'--ckpt'
,
type
=
str
,
default
=
"last"
,
required
=
False
,
help
=
"Define identifier of checkpoint for consolidation. "
"Use this with care!"
)
args
=
parser
.
parse_args
()
model
=
args
.
model
...
...
@@ -120,6 +124,7 @@ def main():
sweep_boxes
=
args
.
sweep_boxes
sweep_instances
=
args
.
sweep_instances
ckpt
=
args
.
ckpt
if
consolidate
==
"export"
and
not
(
sweep_boxes
or
sweep_instances
):
raise
ValueError
(
"Export needs new parameter sweep! Actiate one of the sweep "
...
...
@@ -142,7 +147,10 @@ def main():
# model consolidation
if
do_model_consolidation
:
logger
.
info
(
"Consolidate models"
)
consolidate_models
(
training_dirs
,
target_dir
)
if
ckpt
!=
"last"
:
logger
.
warning
(
f
"Found ckpt overwrite
{
ckpt
}
, this is not the default, "
"this can drastically influence the performance!"
)
consolidate_models
(
training_dirs
,
target_dir
,
ckpt
)
# consolidate predictions
logger
.
info
(
"Consolidate predictions"
)
...
...
scripts/predict.py
View file @
2a8e54b4
...
...
@@ -81,7 +81,7 @@ def run(cfg: dict,
else
:
source_dir
=
preprocessed_output_dir
/
plan
[
"data_identifier"
]
/
"imagesTs"
case_ids
=
None
predict_dir
(
source_dir
=
source_dir
,
target_dir
=
prediction_dir
,
cfg
=
cfg
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment