Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
8607cb0f
Commit
8607cb0f
authored
Nov 24, 2022
by
Baumgartner, Michael
Browse files
Merge remote-tracking branch 'origin/0000_project' into main
parents
1044ace5
ca7e0f11
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
44 additions
and
6 deletions
+44
-6
nndet/core/boxes/nms.py
nndet/core/boxes/nms.py
+6
-1
nndet/evaluator/det.py
nndet/evaluator/det.py
+3
-1
nndet/inference/loading.py
nndet/inference/loading.py
+4
-1
nndet/io/datamodule/base.py
nndet/io/datamodule/base.py
+7
-2
nndet/io/paths.py
nndet/io/paths.py
+5
-1
scripts/train.py
scripts/train.py
+5
-0
scripts/utils.py
scripts/utils.py
+13
-0
setup.py
setup.py
+1
-0
No files found.
nndet/core/boxes/nms.py
View file @
8607cb0f
...
...
@@ -15,11 +15,16 @@ limitations under the License.
"""
import
torch
from
loguru
import
logger
from
torch
import
Tensor
from
torch.cuda.amp
import
autocast
from
torchvision.ops.boxes
import
nms
as
nms_2d
from
nndet._C
import
nms
as
nms_gpu
try
:
from
nndet._C
import
nms
as
nms_gpu
except
ImportError
:
logger
.
warning
(
"nnDetection was not build with GPU support!"
)
nms_gpu
=
None
from
nndet.core.boxes.ops
import
box_iou
...
...
nndet/evaluator/det.py
View file @
8607cb0f
...
...
@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
def
__init__
(
self
,
metrics
:
Sequence
[
DetectionMetric
],
iou_fn
:
Callable
[[
np
.
ndarray
,
np
.
ndarray
],
np
.
ndarray
]
=
box_iou_np
,
match_fn
:
Callable
=
matching_batch
,
max_detections
:
int
=
100
,
):
"""
...
...
@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
max_detections (int): number of maximum detections per image (reduces computation)
"""
self
.
iou_fn
=
iou_fn
self
.
match_fn
=
match_fn
self
.
max_detections
=
max_detections
self
.
metrics
=
metrics
self
.
results_list
=
[]
# store results of each image
...
...
@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
n
=
[
0
if
gt_boxes_img
.
size
==
0
else
gt_boxes_img
.
shape
[
0
]
for
gt_boxes_img
in
gt_boxes
]
gt_ignore
=
[
np
.
zeros
(
_n
).
reshape
(
-
1
)
for
_n
in
n
]
self
.
results_list
.
extend
(
matching_batch
(
self
.
results_list
.
extend
(
self
.
match_fn
(
self
.
iou_fn
,
self
.
iou_thresholds
,
pred_boxes
=
pred_boxes
,
pred_classes
=
pred_classes
,
pred_scores
=
pred_scores
,
gt_boxes
=
gt_boxes
,
gt_classes
=
gt_classes
,
gt_ignore
=
gt_ignore
,
max_detections
=
self
.
max_detections
))
...
...
nndet/inference/loading.py
View file @
8607cb0f
...
...
@@ -22,7 +22,6 @@ from typing import Sequence, Optional
import
torch
from
loguru
import
logger
from
nndet.ptmodule
import
MODULE_REGISTRY
from
nndet.io.paths
import
Pathlike
...
...
@@ -80,6 +79,8 @@ def load_final_model(
`model`: loaded model
`rank`: rank is always 0
"""
from
nndet.ptmodule
import
MODULE_REGISTRY
assert
num_models
==
1
,
f
"load_final_model only supports num_models=1, found
{
num_models
}
"
logger
.
info
(
f
"Loading
{
identifier
}
model"
)
...
...
@@ -123,6 +124,8 @@ def load_all_models(
`model`: loaded model
`rank`: rank of model
"""
from
nndet.ptmodule
import
MODULE_REGISTRY
model_names
=
list
(
source_models
.
glob
(
'*.ckpt'
))
if
not
model_names
:
raise
RuntimeError
(
f
"Did not find any models in
{
source_models
}
"
)
...
...
nndet/io/datamodule/base.py
View file @
8607cb0f
...
...
@@ -57,8 +57,13 @@ class BaseModule(pl.LightningDataModule):
self
.
fold
=
fold
self
.
preprocessed_dir
=
self
.
data_dir
.
parent
.
parent
self
.
splits_file
=
self
.
augment_cfg
.
get
(
"splits_final"
,
"splits_final.pkl"
)
if
"splits"
in
self
.
augment_cfg
:
self
.
splits_file
=
self
.
augment_cfg
[
"splits"
]
elif
"splits_final"
in
self
.
augment_cfg
:
self
.
splits_file
=
self
.
augment_cfg
[
"splits_final"
]
else
:
self
.
splits_file
=
"splits_final"
self
.
dataset_tr
=
{}
self
.
dataset_val
=
{}
...
...
nndet/io/paths.py
View file @
8607cb0f
...
...
@@ -171,7 +171,11 @@ def get_case_id_from_file(file_name: str, remove_modality: bool = True) -> str:
Returns:
str: name of file without ending
"""
file_name
=
file_name
.
split
(
'.'
)[
0
]
if
file_name
.
endswith
(
".nii.gz"
):
file_name
=
file_name
.
rsplit
(
"."
,
2
)[
0
]
else
:
file_name
=
file_name
.
rsplit
(
"."
,
1
)[
0
]
if
remove_modality
:
file_name
=
file_name
[:
-
5
]
return
file_name
...
...
scripts/train.py
View file @
8607cb0f
...
...
@@ -18,6 +18,7 @@ import os
import
sys
import
socket
import
argparse
import
importlib
from
pathlib
import
Path
from
datetime
import
datetime
from
typing
import
List
...
...
@@ -347,6 +348,10 @@ def _sweep(
cfg
=
OmegaConf
.
load
(
str
(
train_dir
/
"config.yaml"
))
os
.
chdir
(
str
(
train_dir
))
for
imp
in
cfg
.
get
(
"additional_imports"
,
[]):
print
(
f
"Additional import found
{
imp
}
"
)
importlib
.
import_module
(
imp
)
logger
.
remove
()
logger
.
add
(
sys
.
stdout
,
format
=
"{level} {message}"
,
level
=
"INFO"
)
log_file
=
Path
(
os
.
getcwd
())
/
"sweep.log"
...
...
scripts/utils.py
View file @
8607cb0f
...
...
@@ -186,6 +186,19 @@ def unpack():
unpack_dataset
(
p
,
num_processes
,
False
)
def
hydra_searchpath
():
from
hydra
import
compose
as
hydra_compose
from
hydra
import
initialize_config_module
initialize_config_module
(
config_module
=
"nndet.conf"
)
cfg
=
hydra_compose
(
"config.yaml"
,
return_hydra_config
=
True
)
print
(
"Found config sources::"
)
print
(
"----------------------"
)
for
s
in
cfg
.
hydra
.
runtime
.
config_sources
:
print
(
s
)
def
env
():
import
os
import
torch
...
...
setup.py
View file @
8607cb0f
...
...
@@ -129,6 +129,7 @@ setup(
'nndet_seg2nii = scripts.utils:seg2nii'
,
'nndet_unpack = scripts.utils:unpack'
,
'nndet_env = scripts.utils:env'
,
'nndet_searchpath = scripts.utils:hydra_searchpath'
]
},
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment