Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
3446e932
Commit
3446e932
authored
Oct 27, 2025
by
Andres Martinez Mora
Browse files
Remove model inference summary printout
parent
e708c342
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
143 additions
and
107 deletions
+143
-107
nndet/ptmodule/retinaunet/base.py
nndet/ptmodule/retinaunet/base.py
+143
-107
No files found.
nndet/ptmodule/retinaunet/base.py
View file @
3446e932
...
...
@@ -68,7 +68,7 @@ from nndet.io.transforms import (
Instances2Boxes
,
Instances2Segmentation
,
FindInstances
,
)
)
class
RetinaUNetModule
(
LightningBaseModuleSWA
):
...
...
@@ -84,12 +84,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
head_sampler_cls
=
HardNegativeSamplerBatched
segmenter_cls
=
DiCESegmenter
def
__init__
(
self
,
model_cfg
:
dict
,
trainer_cfg
:
dict
,
plan
:
dict
,
**
kwargs
):
def
__init__
(
self
,
model_cfg
:
dict
,
trainer_cfg
:
dict
,
plan
:
dict
,
**
kwargs
):
"""
RetinaUNet Lightning Module Skeleton
...
...
@@ -106,7 +101,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
plan
=
plan
,
)
_classes
=
[
f
"class
{
c
}
"
for
c
in
range
(
plan
[
"architecture"
][
"classifier_classes"
])]
_classes
=
[
f
"class
{
c
}
"
for
c
in
range
(
plan
[
"architecture"
][
"classifier_classes"
])
]
self
.
box_evaluator
=
BoxEvaluator
.
create
(
classes
=
_classes
,
fast
=
True
,
...
...
@@ -130,7 +127,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
instance_key
=
"target"
,
map_key
=
"instance_mapping"
,
present_instances
=
"present_instances"
,
)
)
,
)
self
.
eval_score_key
=
"mAP_IoU_0.10_0.50_0.05_MaxDet_100"
...
...
@@ -148,7 +145,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets
=
{
"target_boxes"
:
batch
[
"boxes"
],
"target_classes"
:
batch
[
"classes"
],
"target_seg"
:
batch
[
'
target
'
][:,
0
]
# Remove channel dimension
"target_seg"
:
batch
[
"
target
"
][:,
0
]
,
# Remove channel dimension
},
evaluation
=
False
,
batch_num
=
batch_idx
,
...
...
@@ -167,7 +164,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets
=
{
"target_boxes"
:
batch
[
"boxes"
],
"target_classes"
:
batch
[
"classes"
],
"target_seg"
:
batch
[
'
target
'
][:,
0
]
# Remove channel dimension
"target_seg"
:
batch
[
"
target
"
][:,
0
]
,
# Remove channel dimension
}
losses
,
prediction
=
self
.
model
.
train_step
(
images
=
batch
[
"data"
],
...
...
@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
loss
=
sum
(
losses
.
values
())
self
.
evaluation_step
(
prediction
=
prediction
,
targets
=
targets
)
return
{
"loss"
:
loss
.
detach
().
item
(),
**
{
key
:
l
.
detach
().
item
()
for
key
,
l
in
losses
.
items
()}}
return
{
"loss"
:
loss
.
detach
().
item
(),
**
{
key
:
l
.
detach
().
item
()
for
key
,
l
in
losses
.
items
()},
}
def
evaluation_step
(
self
,
...
...
@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA):
metric_scores
,
_
=
self
.
box_evaluator
.
finish_online_evaluation
()
self
.
box_evaluator
.
reset
()
logger
.
info
(
f
"mAP@0.1:0.5:0.05:
{
metric_scores
[
'mAP_IoU_0.10_0.50_0.05_MaxDet_100'
]:
0.3
f
}
"
logger
.
info
(
f
"mAP@0.1:0.5:0.05:
{
metric_scores
[
'mAP_IoU_0.10_0.50_0.05_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.1:
{
metric_scores
[
'AP_IoU_0.10_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.5:
{
metric_scores
[
'AP_IoU_0.50_MaxDet_100'
]:
0.3
f
}
"
)
f
"AP@0.5:
{
metric_scores
[
'AP_IoU_0.50_MaxDet_100'
]:
0.3
f
}
"
)
seg_scores
,
_
=
self
.
seg_evaluator
.
finish_online_evaluation
()
self
.
seg_evaluator
.
reset
()
...
...
@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
logger
.
info
(
f
"Proxy FG Dice:
{
seg_scores
[
'seg_dice'
]:
0.3
f
}
"
)
for
key
,
item
in
metric_scores
.
items
():
self
.
log
(
f
'
{
key
}
'
,
item
,
on_step
=
None
,
on_epoch
=
True
,
prog_bar
=
False
,
logger
=
True
)
self
.
log
(
f
"
{
key
}
"
,
item
,
on_step
=
None
,
on_epoch
=
True
,
prog_bar
=
False
,
logger
=
True
)
def
configure_optimizers
(
self
):
"""
...
...
@@ -301,11 +304,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
schedule
"""
# configure optimizer
logger
.
info
(
f
"Running: initial_lr
{
self
.
trainer_cfg
[
'initial_lr'
]
}
"
logger
.
info
(
f
"Running: initial_lr
{
self
.
trainer_cfg
[
'initial_lr'
]
}
"
f
"weight_decay
{
self
.
trainer_cfg
[
'weight_decay'
]
}
"
f
"SGD with momentum
{
self
.
trainer_cfg
[
'sgd_momentum'
]
}
and "
f
"nesterov
{
self
.
trainer_cfg
[
'sgd_nesterov'
]
}
"
)
wd_groups
=
get_params_no_wd_on_norm
(
self
,
weight_decay
=
self
.
trainer_cfg
[
'weight_decay'
])
f
"nesterov
{
self
.
trainer_cfg
[
'sgd_nesterov'
]
}
"
)
wd_groups
=
get_params_no_wd_on_norm
(
self
,
weight_decay
=
self
.
trainer_cfg
[
"weight_decay"
]
)
optimizer
=
torch
.
optim
.
SGD
(
wd_groups
,
self
.
trainer_cfg
[
"initial_lr"
],
...
...
@@ -315,19 +322,22 @@ class RetinaUNetModule(LightningBaseModuleSWA):
)
# configure lr scheduler
num_iterations
=
self
.
trainer_cfg
[
"max_num_epochs"
]
*
\
self
.
trainer_cfg
[
"num_train_batches_per_epoch"
]
num_iterations
=
(
self
.
trainer_cfg
[
"max_num_epochs"
]
*
self
.
trainer_cfg
[
"num_train_batches_per_epoch"
]
)
scheduler
=
LinearWarmupPolyLR
(
optimizer
=
optimizer
,
warm_iterations
=
self
.
trainer_cfg
[
"warm_iterations"
],
warm_lr
=
self
.
trainer_cfg
[
"warm_lr"
],
poly_gamma
=
self
.
trainer_cfg
[
"poly_gamma"
],
num_iterations
=
num_iterations
num_iterations
=
num_iterations
,
)
return
[
optimizer
],
{
'
scheduler
'
:
scheduler
,
'
interval
'
:
'
step
'
}
return
[
optimizer
],
{
"
scheduler
"
:
scheduler
,
"
interval
"
:
"
step
"
}
@
classmethod
def
from_config_plan
(
cls
,
def
from_config_plan
(
cls
,
model_cfg
:
dict
,
plan_arch
:
dict
,
plan_anchors
:
dict
,
...
...
@@ -359,21 +369,32 @@ class RetinaUNetModule(LightningBaseModuleSWA):
will be performed
**kwargs:
"""
logger
.
info
(
f
"Architecture overwrites:
{
model_cfg
[
'plan_arch_overwrites'
]
}
"
f
"Anchor overwrites:
{
model_cfg
[
'plan_anchors_overwrites'
]
}
"
)
logger
.
info
(
f
"Building architecture according to plan of
{
plan_arch
.
get
(
'arch_name'
,
'not_found'
)
}
"
)
logger
.
info
(
f
"Architecture overwrites:
{
model_cfg
[
'plan_arch_overwrites'
]
}
"
f
"Anchor overwrites:
{
model_cfg
[
'plan_anchors_overwrites'
]
}
"
)
logger
.
info
(
f
"Building architecture according to plan of
{
plan_arch
.
get
(
'arch_name'
,
'not_found'
)
}
"
)
plan_arch
.
update
(
model_cfg
[
"plan_arch_overwrites"
])
plan_anchors
.
update
(
model_cfg
[
"plan_anchors_overwrites"
])
logger
.
info
(
f
"Start channels:
{
plan_arch
[
'start_channels'
]
}
; "
logger
.
info
(
f
"Start channels:
{
plan_arch
[
'start_channels'
]
}
; "
f
"head channels:
{
plan_arch
[
'head_channels'
]
}
; "
f
"fpn channels:
{
plan_arch
[
'fpn_channels'
]
}
"
)
f
"fpn channels:
{
plan_arch
[
'fpn_channels'
]
}
"
)
_plan_anchors
=
copy
.
deepcopy
(
plan_anchors
)
coder
=
BoxCoderND
(
weights
=
(
1.
,)
*
(
plan_arch
[
"dim"
]
*
2
))
s_param
=
False
if
(
"aspect_ratios"
in
_plan_anchors
)
and
\
(
_plan_anchors
[
"aspect_ratios"
]
is
not
None
)
else
True
anchor_generator
=
get_anchor_generator
(
plan_arch
[
"dim"
],
s_param
=
s_param
)(
**
_plan_anchors
)
coder
=
BoxCoderND
(
weights
=
(
1.0
,)
*
(
plan_arch
[
"dim"
]
*
2
))
s_param
=
(
False
if
(
"aspect_ratios"
in
_plan_anchors
)
and
(
_plan_anchors
[
"aspect_ratios"
]
is
not
None
)
else
True
)
anchor_generator
=
get_anchor_generator
(
plan_arch
[
"dim"
],
s_param
=
s_param
)(
**
_plan_anchors
)
encoder
=
cls
.
_build_encoder
(
plan_arch
=
plan_arch
,
...
...
@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
model_cfg
=
model_cfg
,
classifier
=
classifier
,
regressor
=
regressor
,
coder
=
coder
coder
=
coder
,
)
segmenter
=
cls
.
_build_segmenter
(
plan_arch
=
plan_arch
,
...
...
@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA):
remove_small_boxes
=
plan_arch
.
get
(
"remove_small_boxes"
,
0.01
)
nms_thresh
=
plan_arch
.
get
(
"nms_thresh"
,
0.6
)
logger
.
info
(
f
"Model Inference Summary:
\n
"
f
"detections_per_img:
{
detections_per_img
}
\n
"
f
"score_thresh:
{
score_thresh
}
\n
"
f
"topk_candidates:
{
topk_candidates
}
\n
"
f
"remove_small_boxes:
{
remove_small_boxes
}
\n
"
f
"nms_thresh:
{
nms_thresh
}
"
,
)
#
logger.info(f"Model Inference Summary: \n"
#
f"detections_per_img: {detections_per_img} \n"
#
f"score_thresh: {score_thresh} \n"
#
f"topk_candidates: {topk_candidates} \n"
#
f"remove_small_boxes: {remove_small_boxes} \n"
#
f"nms_thresh: {nms_thresh}",
#
)
return
BaseRetinaNet
(
dim
=
plan_arch
[
"dim"
],
...
...
@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
EncoderType: encoder instance
"""
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: encoder
{
cls
.
encoder_cls
.
__name__
}
:
{
model_cfg
[
'encoder_kwargs'
]
}
"
)
logger
.
info
(
f
"Building:: encoder
{
cls
.
encoder_cls
.
__name__
}
:
{
model_cfg
[
'encoder_kwargs'
]
}
"
)
encoder
=
cls
.
encoder_cls
(
conv
=
conv
,
conv_kernels
=
plan_arch
[
"conv_kernels"
],
...
...
@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
start_channels
=
plan_arch
[
"start_channels"
],
stage_kwargs
=
None
,
max_channels
=
plan_arch
.
get
(
"max_channels"
,
320
),
**
model_cfg
[
'
encoder_kwargs
'
],
**
model_cfg
[
"
encoder_kwargs
"
],
)
return
encoder
...
...
@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
DecoderType: decoder instance
"""
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: decoder
{
cls
.
decoder_cls
.
__name__
}
:
{
model_cfg
[
'decoder_kwargs'
]
}
"
)
logger
.
info
(
f
"Building:: decoder
{
cls
.
decoder_cls
.
__name__
}
:
{
model_cfg
[
'decoder_kwargs'
]
}
"
)
decoder
=
cls
.
decoder_cls
(
conv
=
conv
,
conv_kernels
=
plan_arch
[
"conv_kernels"
],
...
...
@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
in_channels
=
encoder
.
get_channels
(),
decoder_levels
=
plan_arch
[
"decoder_levels"
],
fixed_out_channels
=
plan_arch
[
"fpn_channels"
],
**
model_cfg
[
'
decoder_kwargs
'
],
**
model_cfg
[
"
decoder_kwargs
"
],
)
return
decoder
...
...
@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
name
=
cls
.
head_classifier_cls
.
__name__
kwargs
=
model_cfg
[
'
head_classifier_kwargs
'
]
kwargs
=
model_cfg
[
"
head_classifier_kwargs
"
]
logger
.
info
(
f
"Building:: classifier
{
name
}
:
{
kwargs
}
"
)
classifier
=
cls
.
head_classifier_cls
(
...
...
@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
name
=
cls
.
head_regressor_cls
.
__name__
kwargs
=
model_cfg
[
'
head_regressor_kwargs
'
]
kwargs
=
model_cfg
[
"
head_regressor_kwargs
"
]
logger
.
info
(
f
"Building:: regressor
{
name
}
:
{
kwargs
}
"
)
regressor
=
cls
.
head_regressor_cls
(
...
...
@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA):
HeadType: instantiated head
"""
head_name
=
cls
.
head_cls
.
__name__
head_kwargs
=
model_cfg
[
'
head_kwargs
'
]
head_kwargs
=
model_cfg
[
"
head_kwargs
"
]
sampler_name
=
cls
.
head_sampler_cls
.
__name__
sampler_kwargs
=
model_cfg
[
'
head_sampler_kwargs
'
]
sampler_kwargs
=
model_cfg
[
"
head_sampler_kwargs
"
]
logger
.
info
(
f
"Building:: head
{
head_name
}
:
{
head_kwargs
}
"
f
"sampler
{
sampler_name
}
:
{
sampler_kwargs
}
"
)
logger
.
info
(
f
"Building:: head
{
head_name
}
:
{
head_kwargs
}
"
f
"sampler
{
sampler_name
}
:
{
sampler_kwargs
}
"
)
sampler
=
cls
.
head_sampler_cls
(
**
sampler_kwargs
)
head
=
cls
.
head_cls
(
classifier
=
classifier
,
...
...
@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
if
cls
.
segmenter_cls
is
not
None
:
name
=
cls
.
segmenter_cls
.
__name__
kwargs
=
model_cfg
[
'
segmenter_kwargs
'
]
kwargs
=
model_cfg
[
"
segmenter_kwargs
"
]
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: segmenter
{
name
}
{
kwargs
}
"
)
...
...
@@ -661,14 +688,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
3
:
{
"boxes"
:
BoxEnsemblerSelective
,
"seg"
:
SegmentationEnsembler
,
}
}
,
}
if
dim
==
2
:
raise
NotImplementedError
return
_lookup
[
dim
][
key
]
@
classmethod
def
get_predictor
(
cls
,
def
get_predictor
(
cls
,
plan
:
Dict
,
models
:
Sequence
[
RetinaUNetModule
],
num_tta_transforms
:
int
=
None
,
...
...
@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA):
num_tta_transforms
=
8
if
plan
[
"network_dim"
]
==
3
else
4
# setup
tta_transforms
,
tta_inverse_transforms
=
\
get_tta_transforms
(
num_tta_transforms
,
True
)
logger
.
info
(
f
"Using
{
len
(
tta_transforms
)
}
tta transformations for prediction (one dummy trafo)."
)
tta_transforms
,
tta_inverse_transforms
=
get_tta_transforms
(
num_tta_transforms
,
True
)
logger
.
info
(
f
"Using
{
len
(
tta_transforms
)
}
tta transformations for prediction (one dummy trafo)."
)
ensembler
=
{
"boxes"
:
partial
(
ensembler
=
{
"boxes"
:
partial
(
cls
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
parameters
=
inferene_plan
,
)}
)
}
if
do_seg
:
ensembler
[
"seg"
]
=
partial
(
cls
.
get_ensembler_cls
(
key
=
"seg"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
...
...
@@ -711,7 +744,8 @@ class RetinaUNetModule(LightningBaseModuleSWA):
predictor
.
pre_transform
=
Inference2D
([
"data"
])
return
predictor
def
sweep
(
self
,
def
sweep
(
self
,
cfg
:
dict
,
save_dir
:
os
.
PathLike
,
train_data_dir
:
os
.
PathLike
,
...
...
@@ -767,7 +801,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
)
logger
.
info
(
"Start parameter sweep..."
)
ensembler_cls
=
self
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
self
.
plan
[
"network_dim"
])
ensembler_cls
=
self
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
self
.
plan
[
"network_dim"
]
)
sweeper
=
BoxSweeper
(
classes
=
[
item
for
_
,
item
in
cfg
[
"data"
][
"labels"
].
items
()],
pred_dir
=
prediction_dir
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment