Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
8607cb0f
Commit
8607cb0f
authored
Nov 24, 2022
by
Baumgartner, Michael
Browse files
Merge remote-tracking branch 'origin/0000_project' into main
parents
1044ace5
ca7e0f11
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
44 additions
and
6 deletions
+44
-6
nndet/core/boxes/nms.py
nndet/core/boxes/nms.py
+6
-1
nndet/evaluator/det.py
nndet/evaluator/det.py
+3
-1
nndet/inference/loading.py
nndet/inference/loading.py
+4
-1
nndet/io/datamodule/base.py
nndet/io/datamodule/base.py
+7
-2
nndet/io/paths.py
nndet/io/paths.py
+5
-1
scripts/train.py
scripts/train.py
+5
-0
scripts/utils.py
scripts/utils.py
+13
-0
setup.py
setup.py
+1
-0
No files found.
nndet/core/boxes/nms.py
View file @
8607cb0f
...
@@ -15,11 +15,16 @@ limitations under the License.
...
@@ -15,11 +15,16 @@ limitations under the License.
"""
"""
import
torch
import
torch
from
loguru
import
logger
from
torch
import
Tensor
from
torch
import
Tensor
from
torch.cuda.amp
import
autocast
from
torch.cuda.amp
import
autocast
from
torchvision.ops.boxes
import
nms
as
nms_2d
from
torchvision.ops.boxes
import
nms
as
nms_2d
from
nndet._C
import
nms
as
nms_gpu
try
:
from
nndet._C
import
nms
as
nms_gpu
except
ImportError
:
logger
.
warning
(
"nnDetection was not build with GPU support!"
)
nms_gpu
=
None
from
nndet.core.boxes.ops
import
box_iou
from
nndet.core.boxes.ops
import
box_iou
...
...
nndet/evaluator/det.py
View file @
8607cb0f
...
@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
...
@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
def
__init__
(
self
,
def
__init__
(
self
,
metrics
:
Sequence
[
DetectionMetric
],
metrics
:
Sequence
[
DetectionMetric
],
iou_fn
:
Callable
[[
np
.
ndarray
,
np
.
ndarray
],
np
.
ndarray
]
=
box_iou_np
,
iou_fn
:
Callable
[[
np
.
ndarray
,
np
.
ndarray
],
np
.
ndarray
]
=
box_iou_np
,
match_fn
:
Callable
=
matching_batch
,
max_detections
:
int
=
100
,
max_detections
:
int
=
100
,
):
):
"""
"""
...
@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
...
@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
max_detections (int): number of maximum detections per image (reduces computation)
max_detections (int): number of maximum detections per image (reduces computation)
"""
"""
self
.
iou_fn
=
iou_fn
self
.
iou_fn
=
iou_fn
self
.
match_fn
=
match_fn
self
.
max_detections
=
max_detections
self
.
max_detections
=
max_detections
self
.
metrics
=
metrics
self
.
metrics
=
metrics
self
.
results_list
=
[]
# store results of each image
self
.
results_list
=
[]
# store results of each image
...
@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
...
@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
n
=
[
0
if
gt_boxes_img
.
size
==
0
else
gt_boxes_img
.
shape
[
0
]
for
gt_boxes_img
in
gt_boxes
]
n
=
[
0
if
gt_boxes_img
.
size
==
0
else
gt_boxes_img
.
shape
[
0
]
for
gt_boxes_img
in
gt_boxes
]
gt_ignore
=
[
np
.
zeros
(
_n
).
reshape
(
-
1
)
for
_n
in
n
]
gt_ignore
=
[
np
.
zeros
(
_n
).
reshape
(
-
1
)
for
_n
in
n
]
self
.
results_list
.
extend
(
matching_batch
(
self
.
results_list
.
extend
(
self
.
match_fn
(
self
.
iou_fn
,
self
.
iou_thresholds
,
pred_boxes
=
pred_boxes
,
pred_classes
=
pred_classes
,
self
.
iou_fn
,
self
.
iou_thresholds
,
pred_boxes
=
pred_boxes
,
pred_classes
=
pred_classes
,
pred_scores
=
pred_scores
,
gt_boxes
=
gt_boxes
,
gt_classes
=
gt_classes
,
gt_ignore
=
gt_ignore
,
pred_scores
=
pred_scores
,
gt_boxes
=
gt_boxes
,
gt_classes
=
gt_classes
,
gt_ignore
=
gt_ignore
,
max_detections
=
self
.
max_detections
))
max_detections
=
self
.
max_detections
))
...
...
nndet/inference/loading.py
View file @
8607cb0f
...
@@ -22,7 +22,6 @@ from typing import Sequence, Optional
...
@@ -22,7 +22,6 @@ from typing import Sequence, Optional
import
torch
import
torch
from
loguru
import
logger
from
loguru
import
logger
from
nndet.ptmodule
import
MODULE_REGISTRY
from
nndet.io.paths
import
Pathlike
from
nndet.io.paths
import
Pathlike
...
@@ -80,6 +79,8 @@ def load_final_model(
...
@@ -80,6 +79,8 @@ def load_final_model(
`model`: loaded model
`model`: loaded model
`rank`: rank is always 0
`rank`: rank is always 0
"""
"""
from
nndet.ptmodule
import
MODULE_REGISTRY
assert
num_models
==
1
,
f
"load_final_model only supports num_models=1, found
{
num_models
}
"
assert
num_models
==
1
,
f
"load_final_model only supports num_models=1, found
{
num_models
}
"
logger
.
info
(
f
"Loading
{
identifier
}
model"
)
logger
.
info
(
f
"Loading
{
identifier
}
model"
)
...
@@ -123,6 +124,8 @@ def load_all_models(
...
@@ -123,6 +124,8 @@ def load_all_models(
`model`: loaded model
`model`: loaded model
`rank`: rank of model
`rank`: rank of model
"""
"""
from
nndet.ptmodule
import
MODULE_REGISTRY
model_names
=
list
(
source_models
.
glob
(
'*.ckpt'
))
model_names
=
list
(
source_models
.
glob
(
'*.ckpt'
))
if
not
model_names
:
if
not
model_names
:
raise
RuntimeError
(
f
"Did not find any models in
{
source_models
}
"
)
raise
RuntimeError
(
f
"Did not find any models in
{
source_models
}
"
)
...
...
nndet/io/datamodule/base.py
View file @
8607cb0f
...
@@ -57,8 +57,13 @@ class BaseModule(pl.LightningDataModule):
...
@@ -57,8 +57,13 @@ class BaseModule(pl.LightningDataModule):
self
.
fold
=
fold
self
.
fold
=
fold
self
.
preprocessed_dir
=
self
.
data_dir
.
parent
.
parent
self
.
preprocessed_dir
=
self
.
data_dir
.
parent
.
parent
self
.
splits_file
=
self
.
augment_cfg
.
get
(
"splits_final"
,
"splits_final.pkl"
)
if
"splits"
in
self
.
augment_cfg
:
self
.
splits_file
=
self
.
augment_cfg
[
"splits"
]
elif
"splits_final"
in
self
.
augment_cfg
:
self
.
splits_file
=
self
.
augment_cfg
[
"splits_final"
]
else
:
self
.
splits_file
=
"splits_final"
self
.
dataset_tr
=
{}
self
.
dataset_tr
=
{}
self
.
dataset_val
=
{}
self
.
dataset_val
=
{}
...
...
nndet/io/paths.py
View file @
8607cb0f
...
@@ -171,7 +171,11 @@ def get_case_id_from_file(file_name: str, remove_modality: bool = True) -> str:
...
@@ -171,7 +171,11 @@ def get_case_id_from_file(file_name: str, remove_modality: bool = True) -> str:
Returns:
Returns:
str: name of file without ending
str: name of file without ending
"""
"""
file_name
=
file_name
.
split
(
'.'
)[
0
]
if
file_name
.
endswith
(
".nii.gz"
):
file_name
=
file_name
.
rsplit
(
"."
,
2
)[
0
]
else
:
file_name
=
file_name
.
rsplit
(
"."
,
1
)[
0
]
if
remove_modality
:
if
remove_modality
:
file_name
=
file_name
[:
-
5
]
file_name
=
file_name
[:
-
5
]
return
file_name
return
file_name
...
...
scripts/train.py
View file @
8607cb0f
...
@@ -18,6 +18,7 @@ import os
...
@@ -18,6 +18,7 @@ import os
import
sys
import
sys
import
socket
import
socket
import
argparse
import
argparse
import
importlib
from
pathlib
import
Path
from
pathlib
import
Path
from
datetime
import
datetime
from
datetime
import
datetime
from
typing
import
List
from
typing
import
List
...
@@ -347,6 +348,10 @@ def _sweep(
...
@@ -347,6 +348,10 @@ def _sweep(
cfg
=
OmegaConf
.
load
(
str
(
train_dir
/
"config.yaml"
))
cfg
=
OmegaConf
.
load
(
str
(
train_dir
/
"config.yaml"
))
os
.
chdir
(
str
(
train_dir
))
os
.
chdir
(
str
(
train_dir
))
for
imp
in
cfg
.
get
(
"additional_imports"
,
[]):
print
(
f
"Additional import found
{
imp
}
"
)
importlib
.
import_module
(
imp
)
logger
.
remove
()
logger
.
remove
()
logger
.
add
(
sys
.
stdout
,
format
=
"{level} {message}"
,
level
=
"INFO"
)
logger
.
add
(
sys
.
stdout
,
format
=
"{level} {message}"
,
level
=
"INFO"
)
log_file
=
Path
(
os
.
getcwd
())
/
"sweep.log"
log_file
=
Path
(
os
.
getcwd
())
/
"sweep.log"
...
...
scripts/utils.py
View file @
8607cb0f
...
@@ -186,6 +186,19 @@ def unpack():
...
@@ -186,6 +186,19 @@ def unpack():
unpack_dataset
(
p
,
num_processes
,
False
)
unpack_dataset
(
p
,
num_processes
,
False
)
def
hydra_searchpath
():
from
hydra
import
compose
as
hydra_compose
from
hydra
import
initialize_config_module
initialize_config_module
(
config_module
=
"nndet.conf"
)
cfg
=
hydra_compose
(
"config.yaml"
,
return_hydra_config
=
True
)
print
(
"Found config sources::"
)
print
(
"----------------------"
)
for
s
in
cfg
.
hydra
.
runtime
.
config_sources
:
print
(
s
)
def
env
():
def
env
():
import
os
import
os
import
torch
import
torch
...
...
setup.py
View file @
8607cb0f
...
@@ -129,6 +129,7 @@ setup(
...
@@ -129,6 +129,7 @@ setup(
'nndet_seg2nii = scripts.utils:seg2nii'
,
'nndet_seg2nii = scripts.utils:seg2nii'
,
'nndet_unpack = scripts.utils:unpack'
,
'nndet_unpack = scripts.utils:unpack'
,
'nndet_env = scripts.utils:env'
,
'nndet_env = scripts.utils:env'
,
'nndet_searchpath = scripts.utils:hydra_searchpath'
]
]
},
},
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment