Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
97a58f31
Unverified
Commit
97a58f31
authored
Oct 27, 2025
by
Andres Martinez
Committed by
GitHub
Oct 27, 2025
Browse files
Merge pull request #333 from MIC-DKFZ/origin/0003_noPrintout_inference
Remove model inference summary printout
parents
719016d3
3446e932
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
143 additions
and
107 deletions
+143
-107
nndet/ptmodule/retinaunet/base.py
nndet/ptmodule/retinaunet/base.py
+143
-107
No files found.
nndet/ptmodule/retinaunet/base.py
View file @
97a58f31
...
...
@@ -68,7 +68,7 @@ from nndet.io.transforms import (
Instances2Boxes
,
Instances2Segmentation
,
FindInstances
,
)
)
class
RetinaUNetModule
(
LightningBaseModuleSWA
):
...
...
@@ -84,12 +84,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
head_sampler_cls
=
HardNegativeSamplerBatched
segmenter_cls
=
DiCESegmenter
def
__init__
(
self
,
model_cfg
:
dict
,
trainer_cfg
:
dict
,
plan
:
dict
,
**
kwargs
):
def
__init__
(
self
,
model_cfg
:
dict
,
trainer_cfg
:
dict
,
plan
:
dict
,
**
kwargs
):
"""
RetinaUNet Lightning Module Skeleton
...
...
@@ -106,7 +101,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
plan
=
plan
,
)
_classes
=
[
f
"class
{
c
}
"
for
c
in
range
(
plan
[
"architecture"
][
"classifier_classes"
])]
_classes
=
[
f
"class
{
c
}
"
for
c
in
range
(
plan
[
"architecture"
][
"classifier_classes"
])
]
self
.
box_evaluator
=
BoxEvaluator
.
create
(
classes
=
_classes
,
fast
=
True
,
...
...
@@ -130,7 +127,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
instance_key
=
"target"
,
map_key
=
"instance_mapping"
,
present_instances
=
"present_instances"
,
)
)
,
)
self
.
eval_score_key
=
"mAP_IoU_0.10_0.50_0.05_MaxDet_100"
...
...
@@ -148,7 +145,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets
=
{
"target_boxes"
:
batch
[
"boxes"
],
"target_classes"
:
batch
[
"classes"
],
"target_seg"
:
batch
[
'
target
'
][:,
0
]
# Remove channel dimension
"target_seg"
:
batch
[
"
target
"
][:,
0
]
,
# Remove channel dimension
},
evaluation
=
False
,
batch_num
=
batch_idx
,
...
...
@@ -167,7 +164,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets
=
{
"target_boxes"
:
batch
[
"boxes"
],
"target_classes"
:
batch
[
"classes"
],
"target_seg"
:
batch
[
'
target
'
][:,
0
]
# Remove channel dimension
"target_seg"
:
batch
[
"
target
"
][:,
0
]
,
# Remove channel dimension
}
losses
,
prediction
=
self
.
model
.
train_step
(
images
=
batch
[
"data"
],
...
...
@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
loss
=
sum
(
losses
.
values
())
self
.
evaluation_step
(
prediction
=
prediction
,
targets
=
targets
)
return
{
"loss"
:
loss
.
detach
().
item
(),
**
{
key
:
l
.
detach
().
item
()
for
key
,
l
in
losses
.
items
()}}
return
{
"loss"
:
loss
.
detach
().
item
(),
**
{
key
:
l
.
detach
().
item
()
for
key
,
l
in
losses
.
items
()},
}
def
evaluation_step
(
self
,
...
...
@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA):
metric_scores
,
_
=
self
.
box_evaluator
.
finish_online_evaluation
()
self
.
box_evaluator
.
reset
()
logger
.
info
(
f
"mAP@0.1:0.5:0.05:
{
metric_scores
[
'mAP_IoU_0.10_0.50_0.05_MaxDet_100'
]:
0.3
f
}
"
logger
.
info
(
f
"mAP@0.1:0.5:0.05:
{
metric_scores
[
'mAP_IoU_0.10_0.50_0.05_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.1:
{
metric_scores
[
'AP_IoU_0.10_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.5:
{
metric_scores
[
'AP_IoU_0.50_MaxDet_100'
]:
0.3
f
}
"
)
f
"AP@0.5:
{
metric_scores
[
'AP_IoU_0.50_MaxDet_100'
]:
0.3
f
}
"
)
seg_scores
,
_
=
self
.
seg_evaluator
.
finish_online_evaluation
()
self
.
seg_evaluator
.
reset
()
...
...
@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
logger
.
info
(
f
"Proxy FG Dice:
{
seg_scores
[
'seg_dice'
]:
0.3
f
}
"
)
for
key
,
item
in
metric_scores
.
items
():
self
.
log
(
f
'
{
key
}
'
,
item
,
on_step
=
None
,
on_epoch
=
True
,
prog_bar
=
False
,
logger
=
True
)
self
.
log
(
f
"
{
key
}
"
,
item
,
on_step
=
None
,
on_epoch
=
True
,
prog_bar
=
False
,
logger
=
True
)
def
configure_optimizers
(
self
):
"""
...
...
@@ -301,11 +304,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
schedule
"""
# configure optimizer
logger
.
info
(
f
"Running: initial_lr
{
self
.
trainer_cfg
[
'initial_lr'
]
}
"
logger
.
info
(
f
"Running: initial_lr
{
self
.
trainer_cfg
[
'initial_lr'
]
}
"
f
"weight_decay
{
self
.
trainer_cfg
[
'weight_decay'
]
}
"
f
"SGD with momentum
{
self
.
trainer_cfg
[
'sgd_momentum'
]
}
and "
f
"nesterov
{
self
.
trainer_cfg
[
'sgd_nesterov'
]
}
"
)
wd_groups
=
get_params_no_wd_on_norm
(
self
,
weight_decay
=
self
.
trainer_cfg
[
'weight_decay'
])
f
"nesterov
{
self
.
trainer_cfg
[
'sgd_nesterov'
]
}
"
)
wd_groups
=
get_params_no_wd_on_norm
(
self
,
weight_decay
=
self
.
trainer_cfg
[
"weight_decay"
]
)
optimizer
=
torch
.
optim
.
SGD
(
wd_groups
,
self
.
trainer_cfg
[
"initial_lr"
],
...
...
@@ -315,19 +322,22 @@ class RetinaUNetModule(LightningBaseModuleSWA):
)
# configure lr scheduler
num_iterations
=
self
.
trainer_cfg
[
"max_num_epochs"
]
*
\
self
.
trainer_cfg
[
"num_train_batches_per_epoch"
]
num_iterations
=
(
self
.
trainer_cfg
[
"max_num_epochs"
]
*
self
.
trainer_cfg
[
"num_train_batches_per_epoch"
]
)
scheduler
=
LinearWarmupPolyLR
(
optimizer
=
optimizer
,
warm_iterations
=
self
.
trainer_cfg
[
"warm_iterations"
],
warm_lr
=
self
.
trainer_cfg
[
"warm_lr"
],
poly_gamma
=
self
.
trainer_cfg
[
"poly_gamma"
],
num_iterations
=
num_iterations
num_iterations
=
num_iterations
,
)
return
[
optimizer
],
{
'
scheduler
'
:
scheduler
,
'
interval
'
:
'
step
'
}
return
[
optimizer
],
{
"
scheduler
"
:
scheduler
,
"
interval
"
:
"
step
"
}
@
classmethod
def
from_config_plan
(
cls
,
def
from_config_plan
(
cls
,
model_cfg
:
dict
,
plan_arch
:
dict
,
plan_anchors
:
dict
,
...
...
@@ -359,21 +369,32 @@ class RetinaUNetModule(LightningBaseModuleSWA):
will be performed
**kwargs:
"""
logger
.
info
(
f
"Architecture overwrites:
{
model_cfg
[
'plan_arch_overwrites'
]
}
"
f
"Anchor overwrites:
{
model_cfg
[
'plan_anchors_overwrites'
]
}
"
)
logger
.
info
(
f
"Building architecture according to plan of
{
plan_arch
.
get
(
'arch_name'
,
'not_found'
)
}
"
)
logger
.
info
(
f
"Architecture overwrites:
{
model_cfg
[
'plan_arch_overwrites'
]
}
"
f
"Anchor overwrites:
{
model_cfg
[
'plan_anchors_overwrites'
]
}
"
)
logger
.
info
(
f
"Building architecture according to plan of
{
plan_arch
.
get
(
'arch_name'
,
'not_found'
)
}
"
)
plan_arch
.
update
(
model_cfg
[
"plan_arch_overwrites"
])
plan_anchors
.
update
(
model_cfg
[
"plan_anchors_overwrites"
])
logger
.
info
(
f
"Start channels:
{
plan_arch
[
'start_channels'
]
}
; "
logger
.
info
(
f
"Start channels:
{
plan_arch
[
'start_channels'
]
}
; "
f
"head channels:
{
plan_arch
[
'head_channels'
]
}
; "
f
"fpn channels:
{
plan_arch
[
'fpn_channels'
]
}
"
)
f
"fpn channels:
{
plan_arch
[
'fpn_channels'
]
}
"
)
_plan_anchors
=
copy
.
deepcopy
(
plan_anchors
)
coder
=
BoxCoderND
(
weights
=
(
1.
,)
*
(
plan_arch
[
"dim"
]
*
2
))
s_param
=
False
if
(
"aspect_ratios"
in
_plan_anchors
)
and
\
(
_plan_anchors
[
"aspect_ratios"
]
is
not
None
)
else
True
anchor_generator
=
get_anchor_generator
(
plan_arch
[
"dim"
],
s_param
=
s_param
)(
**
_plan_anchors
)
coder
=
BoxCoderND
(
weights
=
(
1.0
,)
*
(
plan_arch
[
"dim"
]
*
2
))
s_param
=
(
False
if
(
"aspect_ratios"
in
_plan_anchors
)
and
(
_plan_anchors
[
"aspect_ratios"
]
is
not
None
)
else
True
)
anchor_generator
=
get_anchor_generator
(
plan_arch
[
"dim"
],
s_param
=
s_param
)(
**
_plan_anchors
)
encoder
=
cls
.
_build_encoder
(
plan_arch
=
plan_arch
,
...
...
@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
model_cfg
=
model_cfg
,
classifier
=
classifier
,
regressor
=
regressor
,
coder
=
coder
coder
=
coder
,
)
segmenter
=
cls
.
_build_segmenter
(
plan_arch
=
plan_arch
,
...
...
@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA):
remove_small_boxes
=
plan_arch
.
get
(
"remove_small_boxes"
,
0.01
)
nms_thresh
=
plan_arch
.
get
(
"nms_thresh"
,
0.6
)
logger
.
info
(
f
"Model Inference Summary:
\n
"
f
"detections_per_img:
{
detections_per_img
}
\n
"
f
"score_thresh:
{
score_thresh
}
\n
"
f
"topk_candidates:
{
topk_candidates
}
\n
"
f
"remove_small_boxes:
{
remove_small_boxes
}
\n
"
f
"nms_thresh:
{
nms_thresh
}
"
,
)
#
logger.info(f"Model Inference Summary: \n"
#
f"detections_per_img: {detections_per_img} \n"
#
f"score_thresh: {score_thresh} \n"
#
f"topk_candidates: {topk_candidates} \n"
#
f"remove_small_boxes: {remove_small_boxes} \n"
#
f"nms_thresh: {nms_thresh}",
#
)
return
BaseRetinaNet
(
dim
=
plan_arch
[
"dim"
],
...
...
@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
EncoderType: encoder instance
"""
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: encoder
{
cls
.
encoder_cls
.
__name__
}
:
{
model_cfg
[
'encoder_kwargs'
]
}
"
)
logger
.
info
(
f
"Building:: encoder
{
cls
.
encoder_cls
.
__name__
}
:
{
model_cfg
[
'encoder_kwargs'
]
}
"
)
encoder
=
cls
.
encoder_cls
(
conv
=
conv
,
conv_kernels
=
plan_arch
[
"conv_kernels"
],
...
...
@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
start_channels
=
plan_arch
[
"start_channels"
],
stage_kwargs
=
None
,
max_channels
=
plan_arch
.
get
(
"max_channels"
,
320
),
**
model_cfg
[
'
encoder_kwargs
'
],
**
model_cfg
[
"
encoder_kwargs
"
],
)
return
encoder
...
...
@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
DecoderType: decoder instance
"""
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: decoder
{
cls
.
decoder_cls
.
__name__
}
:
{
model_cfg
[
'decoder_kwargs'
]
}
"
)
logger
.
info
(
f
"Building:: decoder
{
cls
.
decoder_cls
.
__name__
}
:
{
model_cfg
[
'decoder_kwargs'
]
}
"
)
decoder
=
cls
.
decoder_cls
(
conv
=
conv
,
conv_kernels
=
plan_arch
[
"conv_kernels"
],
...
...
@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
in_channels
=
encoder
.
get_channels
(),
decoder_levels
=
plan_arch
[
"decoder_levels"
],
fixed_out_channels
=
plan_arch
[
"fpn_channels"
],
**
model_cfg
[
'
decoder_kwargs
'
],
**
model_cfg
[
"
decoder_kwargs
"
],
)
return
decoder
...
...
@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
name
=
cls
.
head_classifier_cls
.
__name__
kwargs
=
model_cfg
[
'
head_classifier_kwargs
'
]
kwargs
=
model_cfg
[
"
head_classifier_kwargs
"
]
logger
.
info
(
f
"Building:: classifier
{
name
}
:
{
kwargs
}
"
)
classifier
=
cls
.
head_classifier_cls
(
...
...
@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
name
=
cls
.
head_regressor_cls
.
__name__
kwargs
=
model_cfg
[
'
head_regressor_kwargs
'
]
kwargs
=
model_cfg
[
"
head_regressor_kwargs
"
]
logger
.
info
(
f
"Building:: regressor
{
name
}
:
{
kwargs
}
"
)
regressor
=
cls
.
head_regressor_cls
(
...
...
@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA):
HeadType: instantiated head
"""
head_name
=
cls
.
head_cls
.
__name__
head_kwargs
=
model_cfg
[
'
head_kwargs
'
]
head_kwargs
=
model_cfg
[
"
head_kwargs
"
]
sampler_name
=
cls
.
head_sampler_cls
.
__name__
sampler_kwargs
=
model_cfg
[
'
head_sampler_kwargs
'
]
sampler_kwargs
=
model_cfg
[
"
head_sampler_kwargs
"
]
logger
.
info
(
f
"Building:: head
{
head_name
}
:
{
head_kwargs
}
"
f
"sampler
{
sampler_name
}
:
{
sampler_kwargs
}
"
)
logger
.
info
(
f
"Building:: head
{
head_name
}
:
{
head_kwargs
}
"
f
"sampler
{
sampler_name
}
:
{
sampler_kwargs
}
"
)
sampler
=
cls
.
head_sampler_cls
(
**
sampler_kwargs
)
head
=
cls
.
head_cls
(
classifier
=
classifier
,
...
...
@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
if
cls
.
segmenter_cls
is
not
None
:
name
=
cls
.
segmenter_cls
.
__name__
kwargs
=
model_cfg
[
'
segmenter_kwargs
'
]
kwargs
=
model_cfg
[
"
segmenter_kwargs
"
]
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: segmenter
{
name
}
{
kwargs
}
"
)
...
...
@@ -661,14 +688,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
3
:
{
"boxes"
:
BoxEnsemblerSelective
,
"seg"
:
SegmentationEnsembler
,
}
}
,
}
if
dim
==
2
:
raise
NotImplementedError
return
_lookup
[
dim
][
key
]
@
classmethod
def
get_predictor
(
cls
,
def
get_predictor
(
cls
,
plan
:
Dict
,
models
:
Sequence
[
RetinaUNetModule
],
num_tta_transforms
:
int
=
None
,
...
...
@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA):
num_tta_transforms
=
8
if
plan
[
"network_dim"
]
==
3
else
4
# setup
tta_transforms
,
tta_inverse_transforms
=
\
get_tta_transforms
(
num_tta_transforms
,
True
)
logger
.
info
(
f
"Using
{
len
(
tta_transforms
)
}
tta transformations for prediction (one dummy trafo)."
)
tta_transforms
,
tta_inverse_transforms
=
get_tta_transforms
(
num_tta_transforms
,
True
)
logger
.
info
(
f
"Using
{
len
(
tta_transforms
)
}
tta transformations for prediction (one dummy trafo)."
)
ensembler
=
{
"boxes"
:
partial
(
ensembler
=
{
"boxes"
:
partial
(
cls
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
parameters
=
inferene_plan
,
)}
)
}
if
do_seg
:
ensembler
[
"seg"
]
=
partial
(
cls
.
get_ensembler_cls
(
key
=
"seg"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
...
...
@@ -711,7 +744,8 @@ class RetinaUNetModule(LightningBaseModuleSWA):
predictor
.
pre_transform
=
Inference2D
([
"data"
])
return
predictor
def
sweep
(
self
,
def
sweep
(
self
,
cfg
:
dict
,
save_dir
:
os
.
PathLike
,
train_data_dir
:
os
.
PathLike
,
...
...
@@ -767,7 +801,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
)
logger
.
info
(
"Start parameter sweep..."
)
ensembler_cls
=
self
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
self
.
plan
[
"network_dim"
])
ensembler_cls
=
self
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
self
.
plan
[
"network_dim"
]
)
sweeper
=
BoxSweeper
(
classes
=
[
item
for
_
,
item
in
cfg
[
"data"
][
"labels"
].
items
()],
pred_dir
=
prediction_dir
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment