Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
724ba674
Commit
724ba674
authored
May 05, 2021
by
mibaumgartner
Browse files
minor improvements in predict script
parent
0f3a95f3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
3 deletions
+29
-3
scripts/predict.py
scripts/predict.py
+29
-3
No files found.
scripts/predict.py
View file @
724ba674
...
@@ -36,6 +36,7 @@ def run(cfg: dict,
...
@@ -36,6 +36,7 @@ def run(cfg: dict,
process
:
bool
=
True
,
process
:
bool
=
True
,
num_models
:
int
=
None
,
num_models
:
int
=
None
,
num_tta_transforms
:
int
=
None
,
num_tta_transforms
:
int
=
None
,
test_split
:
bool
=
False
,
):
):
"""
"""
Run inference pipeline
Run inference pipeline
...
@@ -48,6 +49,10 @@ def run(cfg: dict,
...
@@ -48,6 +49,10 @@ def run(cfg: dict,
are used
are used
num_tta_transforms: number of tta transformation; if None the maximum
num_tta_transforms: number of tta transformation; if None the maximum
number of transformation is used
number of transformation is used
test_split: Typical usage of nnDetection will never require
this option! Predict an already preprocessed split of the original
training data. The 'test' split needs to be located in fold 0
of a manually created split file.
"""
"""
plan
=
load_pickle
(
training_dir
/
"plan_inference.pkl"
)
plan
=
load_pickle
(
training_dir
/
"plan_inference.pkl"
)
...
@@ -68,7 +73,13 @@ def run(cfg: dict,
...
@@ -68,7 +73,13 @@ def run(cfg: dict,
)
)
prediction_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
prediction_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
source_dir
=
preprocessed_output_dir
/
plan
[
"data_identifier"
]
/
"imagesTs"
if
test_split
:
source_dir
=
preprocessed_output_dir
/
plan
[
"data_identifier"
]
/
"imagesTr"
case_ids
=
load_pickle
(
training_dir
/
"splits.pkl"
)[
0
][
"test"
]
else
:
source_dir
=
preprocessed_output_dir
/
plan
[
"data_identifier"
]
/
"imagesTs"
case_ids
=
None
predict_dir
(
source_dir
=
source_dir
,
predict_dir
(
source_dir
=
source_dir
,
target_dir
=
prediction_dir
,
target_dir
=
prediction_dir
,
cfg
=
cfg
,
cfg
=
cfg
,
...
@@ -78,7 +89,8 @@ def run(cfg: dict,
...
@@ -78,7 +89,8 @@ def run(cfg: dict,
num_tta_transforms
=
num_tta_transforms
,
num_tta_transforms
=
num_tta_transforms
,
model_fn
=
load_all_models
,
model_fn
=
load_all_models
,
restore
=
True
,
restore
=
True
,
# do_seg=True, # TODO: change this...
case_ids
=
case_ids
,
**
cfg
.
get
(
"inference_kwargs"
,
{}),
)
)
...
@@ -129,7 +141,10 @@ def main():
...
@@ -129,7 +141,10 @@ def main():
help
=
"number of tta transforms (per default most tta are chosen)"
,
help
=
"number of tta transforms (per default most tta are chosen)"
,
required
=
False
)
required
=
False
)
parser
.
add_argument
(
'-o'
,
'--overwrites'
,
type
=
str
,
nargs
=
'+'
,
parser
.
add_argument
(
'-o'
,
'--overwrites'
,
type
=
str
,
nargs
=
'+'
,
help
=
"overwrites for config file"
,
default
=
None
,
help
=
(
"overwrites for config file. "
"inference_kwargs can be used to add additional "
"keyword arguments to inference."
),
default
=
None
,
required
=
False
)
required
=
False
)
parser
.
add_argument
(
'--no_preprocess'
,
help
=
"Preprocess test data"
,
action
=
'store_false'
)
parser
.
add_argument
(
'--no_preprocess'
,
help
=
"Preprocess test data"
,
action
=
'store_false'
)
parser
.
add_argument
(
'--force_args'
,
parser
.
add_argument
(
'--force_args'
,
...
@@ -137,6 +152,12 @@ def main():
...
@@ -137,6 +152,12 @@ def main():
"and fold might differ from the original one. "
"and fold might differ from the original one. "
"This forces an overwrite to the passed in arguments of"
"This forces an overwrite to the passed in arguments of"
" this function. This can be dangerous!"
),
action
=
'store_true'
)
" this function. This can be dangerous!"
),
action
=
'store_true'
)
parser
.
add_argument
(
'--test_split'
,
help
=
(
"Typical usage of nnDetection will never require "
"this option! Predict an already preprocessed "
"split of the original training data. "
"The 'test' split needs to be located in fold 0 "
"of a manually created split file."
))
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
model
=
args
.
model
model
=
args
.
model
...
@@ -146,12 +167,16 @@ def main():
...
@@ -146,12 +167,16 @@ def main():
num_tta_transforms
=
args
.
num_tta
num_tta_transforms
=
args
.
num_tta
ov
=
args
.
overwrites
ov
=
args
.
overwrites
force_args
=
args
.
force_args
force_args
=
args
.
force_args
test_split
=
args
.
test_split
task_name
=
get_task
(
task
,
name
=
True
)
task_name
=
get_task
(
task
,
name
=
True
)
task_model_dir
=
Path
(
os
.
getenv
(
"det_models"
))
task_model_dir
=
Path
(
os
.
getenv
(
"det_models"
))
training_dir
=
get_training_dir
(
task_model_dir
/
task_name
/
model
,
fold
)
training_dir
=
get_training_dir
(
task_model_dir
/
task_name
/
model
,
fold
)
process
=
args
.
no_preprocess
process
=
args
.
no_preprocess
if
test_split
:
raise
ValueError
(
"When using the test split option raw data is not "
"supported. Need to add --no_preprocess flag!"
)
cfg
=
OmegaConf
.
load
(
str
(
training_dir
/
"config.yaml"
))
cfg
=
OmegaConf
.
load
(
str
(
training_dir
/
"config.yaml"
))
...
@@ -170,6 +195,7 @@ def main():
...
@@ -170,6 +195,7 @@ def main():
process
=
process
,
process
=
process
,
num_models
=
num_models
,
num_models
=
num_models
,
num_tta_transforms
=
num_tta_transforms
,
num_tta_transforms
=
num_tta_transforms
,
test_split
=
test_split
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment