Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
0cdad98d
Commit
0cdad98d
authored
Jan 24, 2022
by
mibaumgartner
Browse files
utils
parent
9853d3e4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
26 additions
and
2 deletions
+26
-2
nndet/evaluator/det.py
nndet/evaluator/det.py
+3
-1
nndet/inference/loading.py
nndet/inference/loading.py
+4
-1
scripts/train.py
scripts/train.py
+5
-0
scripts/utils.py
scripts/utils.py
+13
-0
setup.py
setup.py
+1
-0
No files found.
nndet/evaluator/det.py
View file @
0cdad98d
...
@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
...
@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
def
__init__
(
self
,
def
__init__
(
self
,
metrics
:
Sequence
[
DetectionMetric
],
metrics
:
Sequence
[
DetectionMetric
],
iou_fn
:
Callable
[[
np
.
ndarray
,
np
.
ndarray
],
np
.
ndarray
]
=
box_iou_np
,
iou_fn
:
Callable
[[
np
.
ndarray
,
np
.
ndarray
],
np
.
ndarray
]
=
box_iou_np
,
match_fn
:
Callable
=
matching_batch
,
max_detections
:
int
=
100
,
max_detections
:
int
=
100
,
):
):
"""
"""
...
@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
...
@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
max_detections (int): number of maximum detections per image (reduces computation)
max_detections (int): number of maximum detections per image (reduces computation)
"""
"""
self
.
iou_fn
=
iou_fn
self
.
iou_fn
=
iou_fn
self
.
match_fn
=
match_fn
self
.
max_detections
=
max_detections
self
.
max_detections
=
max_detections
self
.
metrics
=
metrics
self
.
metrics
=
metrics
self
.
results_list
=
[]
# store results of each image
self
.
results_list
=
[]
# store results of each image
...
@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
...
@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
n
=
[
0
if
gt_boxes_img
.
size
==
0
else
gt_boxes_img
.
shape
[
0
]
for
gt_boxes_img
in
gt_boxes
]
n
=
[
0
if
gt_boxes_img
.
size
==
0
else
gt_boxes_img
.
shape
[
0
]
for
gt_boxes_img
in
gt_boxes
]
gt_ignore
=
[
np
.
zeros
(
_n
).
reshape
(
-
1
)
for
_n
in
n
]
gt_ignore
=
[
np
.
zeros
(
_n
).
reshape
(
-
1
)
for
_n
in
n
]
self
.
results_list
.
extend
(
matching_batch
(
self
.
results_list
.
extend
(
self
.
match_fn
(
self
.
iou_fn
,
self
.
iou_thresholds
,
pred_boxes
=
pred_boxes
,
pred_classes
=
pred_classes
,
self
.
iou_fn
,
self
.
iou_thresholds
,
pred_boxes
=
pred_boxes
,
pred_classes
=
pred_classes
,
pred_scores
=
pred_scores
,
gt_boxes
=
gt_boxes
,
gt_classes
=
gt_classes
,
gt_ignore
=
gt_ignore
,
pred_scores
=
pred_scores
,
gt_boxes
=
gt_boxes
,
gt_classes
=
gt_classes
,
gt_ignore
=
gt_ignore
,
max_detections
=
self
.
max_detections
))
max_detections
=
self
.
max_detections
))
...
...
nndet/inference/loading.py
View file @
0cdad98d
...
@@ -22,7 +22,6 @@ from typing import Sequence, Optional
...
@@ -22,7 +22,6 @@ from typing import Sequence, Optional
import
torch
import
torch
from
loguru
import
logger
from
loguru
import
logger
from
nndet.ptmodule
import
MODULE_REGISTRY
from
nndet.io.paths
import
Pathlike
from
nndet.io.paths
import
Pathlike
...
@@ -80,6 +79,8 @@ def load_final_model(
...
@@ -80,6 +79,8 @@ def load_final_model(
`model`: loaded model
`model`: loaded model
`rank`: rank is always 0
`rank`: rank is always 0
"""
"""
from
nndet.ptmodule
import
MODULE_REGISTRY
assert
num_models
==
1
,
f
"load_final_model only supports num_models=1, found
{
num_models
}
"
assert
num_models
==
1
,
f
"load_final_model only supports num_models=1, found
{
num_models
}
"
logger
.
info
(
f
"Loading
{
identifier
}
model"
)
logger
.
info
(
f
"Loading
{
identifier
}
model"
)
...
@@ -123,6 +124,8 @@ def load_all_models(
...
@@ -123,6 +124,8 @@ def load_all_models(
`model`: loaded model
`model`: loaded model
`rank`: rank of model
`rank`: rank of model
"""
"""
from
nndet.ptmodule
import
MODULE_REGISTRY
model_names
=
list
(
source_models
.
glob
(
'*.ckpt'
))
model_names
=
list
(
source_models
.
glob
(
'*.ckpt'
))
if
not
model_names
:
if
not
model_names
:
raise
RuntimeError
(
f
"Did not find any models in
{
source_models
}
"
)
raise
RuntimeError
(
f
"Did not find any models in
{
source_models
}
"
)
...
...
scripts/train.py
View file @
0cdad98d
...
@@ -18,6 +18,7 @@ import os
...
@@ -18,6 +18,7 @@ import os
import
sys
import
sys
import
socket
import
socket
import
argparse
import
argparse
import
importlib
from
pathlib
import
Path
from
pathlib
import
Path
from
datetime
import
datetime
from
datetime
import
datetime
from
typing
import
List
from
typing
import
List
...
@@ -347,6 +348,10 @@ def _sweep(
...
@@ -347,6 +348,10 @@ def _sweep(
cfg
=
OmegaConf
.
load
(
str
(
train_dir
/
"config.yaml"
))
cfg
=
OmegaConf
.
load
(
str
(
train_dir
/
"config.yaml"
))
os
.
chdir
(
str
(
train_dir
))
os
.
chdir
(
str
(
train_dir
))
for
imp
in
cfg
.
get
(
"additional_imports"
,
[]):
print
(
f
"Additional import found
{
imp
}
"
)
importlib
.
import_module
(
imp
)
logger
.
remove
()
logger
.
remove
()
logger
.
add
(
sys
.
stdout
,
format
=
"{level} {message}"
,
level
=
"INFO"
)
logger
.
add
(
sys
.
stdout
,
format
=
"{level} {message}"
,
level
=
"INFO"
)
log_file
=
Path
(
os
.
getcwd
())
/
"sweep.log"
log_file
=
Path
(
os
.
getcwd
())
/
"sweep.log"
...
...
scripts/utils.py
View file @
0cdad98d
...
@@ -186,6 +186,19 @@ def unpack():
...
@@ -186,6 +186,19 @@ def unpack():
unpack_dataset
(
p
,
num_processes
,
False
)
unpack_dataset
(
p
,
num_processes
,
False
)
def
hydra_searchpath
():
from
hydra
import
compose
as
hydra_compose
from
hydra
import
initialize_config_module
initialize_config_module
(
config_module
=
"nndet.conf"
)
cfg
=
hydra_compose
(
"config.yaml"
,
return_hydra_config
=
True
)
print
(
"Found config sources::"
)
print
(
"----------------------"
)
for
s
in
cfg
.
hydra
.
runtime
.
config_sources
:
print
(
s
)
def
env
():
def
env
():
import
os
import
os
import
torch
import
torch
...
...
setup.py
View file @
0cdad98d
...
@@ -129,6 +129,7 @@ setup(
...
@@ -129,6 +129,7 @@ setup(
'nndet_seg2nii = scripts.utils:seg2nii'
,
'nndet_seg2nii = scripts.utils:seg2nii'
,
'nndet_unpack = scripts.utils:unpack'
,
'nndet_unpack = scripts.utils:unpack'
,
'nndet_env = scripts.utils:env'
,
'nndet_env = scripts.utils:env'
,
'nndet_searchpath = scripts.utils:hydra_searchpath'
]
]
},
},
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment