Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
97a58f31
Unverified
Commit
97a58f31
authored
Oct 27, 2025
by
Andres Martinez
Committed by
GitHub
Oct 27, 2025
Browse files
Merge pull request #333 from MIC-DKFZ/origin/0003_noPrintout_inference
Remove model inference summary printout
parents
719016d3
3446e932
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
143 additions
and
107 deletions
+143
-107
nndet/ptmodule/retinaunet/base.py
nndet/ptmodule/retinaunet/base.py
+143
-107
No files found.
nndet/ptmodule/retinaunet/base.py
View file @
97a58f31
...
@@ -68,7 +68,7 @@ from nndet.io.transforms import (
...
@@ -68,7 +68,7 @@ from nndet.io.transforms import (
Instances2Boxes
,
Instances2Boxes
,
Instances2Segmentation
,
Instances2Segmentation
,
FindInstances
,
FindInstances
,
)
)
class
RetinaUNetModule
(
LightningBaseModuleSWA
):
class
RetinaUNetModule
(
LightningBaseModuleSWA
):
...
@@ -84,12 +84,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -84,12 +84,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
head_sampler_cls
=
HardNegativeSamplerBatched
head_sampler_cls
=
HardNegativeSamplerBatched
segmenter_cls
=
DiCESegmenter
segmenter_cls
=
DiCESegmenter
def
__init__
(
self
,
def
__init__
(
self
,
model_cfg
:
dict
,
trainer_cfg
:
dict
,
plan
:
dict
,
**
kwargs
):
model_cfg
:
dict
,
trainer_cfg
:
dict
,
plan
:
dict
,
**
kwargs
):
"""
"""
RetinaUNet Lightning Module Skeleton
RetinaUNet Lightning Module Skeleton
...
@@ -106,7 +101,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -106,7 +101,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
plan
=
plan
,
plan
=
plan
,
)
)
_classes
=
[
f
"class
{
c
}
"
for
c
in
range
(
plan
[
"architecture"
][
"classifier_classes"
])]
_classes
=
[
f
"class
{
c
}
"
for
c
in
range
(
plan
[
"architecture"
][
"classifier_classes"
])
]
self
.
box_evaluator
=
BoxEvaluator
.
create
(
self
.
box_evaluator
=
BoxEvaluator
.
create
(
classes
=
_classes
,
classes
=
_classes
,
fast
=
True
,
fast
=
True
,
...
@@ -130,7 +127,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -130,7 +127,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
instance_key
=
"target"
,
instance_key
=
"target"
,
map_key
=
"instance_mapping"
,
map_key
=
"instance_mapping"
,
present_instances
=
"present_instances"
,
present_instances
=
"present_instances"
,
)
)
,
)
)
self
.
eval_score_key
=
"mAP_IoU_0.10_0.50_0.05_MaxDet_100"
self
.
eval_score_key
=
"mAP_IoU_0.10_0.50_0.05_MaxDet_100"
...
@@ -148,7 +145,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -148,7 +145,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets
=
{
targets
=
{
"target_boxes"
:
batch
[
"boxes"
],
"target_boxes"
:
batch
[
"boxes"
],
"target_classes"
:
batch
[
"classes"
],
"target_classes"
:
batch
[
"classes"
],
"target_seg"
:
batch
[
'
target
'
][:,
0
]
# Remove channel dimension
"target_seg"
:
batch
[
"
target
"
][:,
0
]
,
# Remove channel dimension
},
},
evaluation
=
False
,
evaluation
=
False
,
batch_num
=
batch_idx
,
batch_num
=
batch_idx
,
...
@@ -167,7 +164,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -167,7 +164,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets
=
{
targets
=
{
"target_boxes"
:
batch
[
"boxes"
],
"target_boxes"
:
batch
[
"boxes"
],
"target_classes"
:
batch
[
"classes"
],
"target_classes"
:
batch
[
"classes"
],
"target_seg"
:
batch
[
'
target
'
][:,
0
]
# Remove channel dimension
"target_seg"
:
batch
[
"
target
"
][:,
0
]
,
# Remove channel dimension
}
}
losses
,
prediction
=
self
.
model
.
train_step
(
losses
,
prediction
=
self
.
model
.
train_step
(
images
=
batch
[
"data"
],
images
=
batch
[
"data"
],
...
@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
loss
=
sum
(
losses
.
values
())
loss
=
sum
(
losses
.
values
())
self
.
evaluation_step
(
prediction
=
prediction
,
targets
=
targets
)
self
.
evaluation_step
(
prediction
=
prediction
,
targets
=
targets
)
return
{
"loss"
:
loss
.
detach
().
item
(),
return
{
**
{
key
:
l
.
detach
().
item
()
for
key
,
l
in
losses
.
items
()}}
"loss"
:
loss
.
detach
().
item
(),
**
{
key
:
l
.
detach
().
item
()
for
key
,
l
in
losses
.
items
()},
}
def
evaluation_step
(
def
evaluation_step
(
self
,
self
,
...
@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA):
metric_scores
,
_
=
self
.
box_evaluator
.
finish_online_evaluation
()
metric_scores
,
_
=
self
.
box_evaluator
.
finish_online_evaluation
()
self
.
box_evaluator
.
reset
()
self
.
box_evaluator
.
reset
()
logger
.
info
(
f
"mAP@0.1:0.5:0.05:
{
metric_scores
[
'mAP_IoU_0.10_0.50_0.05_MaxDet_100'
]:
0.3
f
}
"
logger
.
info
(
f
"mAP@0.1:0.5:0.05:
{
metric_scores
[
'mAP_IoU_0.10_0.50_0.05_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.1:
{
metric_scores
[
'AP_IoU_0.10_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.1:
{
metric_scores
[
'AP_IoU_0.10_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.5:
{
metric_scores
[
'AP_IoU_0.50_MaxDet_100'
]:
0.3
f
}
"
)
f
"AP@0.5:
{
metric_scores
[
'AP_IoU_0.50_MaxDet_100'
]:
0.3
f
}
"
)
seg_scores
,
_
=
self
.
seg_evaluator
.
finish_online_evaluation
()
seg_scores
,
_
=
self
.
seg_evaluator
.
finish_online_evaluation
()
self
.
seg_evaluator
.
reset
()
self
.
seg_evaluator
.
reset
()
...
@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
logger
.
info
(
f
"Proxy FG Dice:
{
seg_scores
[
'seg_dice'
]:
0.3
f
}
"
)
logger
.
info
(
f
"Proxy FG Dice:
{
seg_scores
[
'seg_dice'
]:
0.3
f
}
"
)
for
key
,
item
in
metric_scores
.
items
():
for
key
,
item
in
metric_scores
.
items
():
self
.
log
(
f
'
{
key
}
'
,
item
,
on_step
=
None
,
on_epoch
=
True
,
prog_bar
=
False
,
logger
=
True
)
self
.
log
(
f
"
{
key
}
"
,
item
,
on_step
=
None
,
on_epoch
=
True
,
prog_bar
=
False
,
logger
=
True
)
def
configure_optimizers
(
self
):
def
configure_optimizers
(
self
):
"""
"""
...
@@ -301,11 +304,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -301,11 +304,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
schedule
schedule
"""
"""
# configure optimizer
# configure optimizer
logger
.
info
(
f
"Running: initial_lr
{
self
.
trainer_cfg
[
'initial_lr'
]
}
"
logger
.
info
(
f
"Running: initial_lr
{
self
.
trainer_cfg
[
'initial_lr'
]
}
"
f
"weight_decay
{
self
.
trainer_cfg
[
'weight_decay'
]
}
"
f
"weight_decay
{
self
.
trainer_cfg
[
'weight_decay'
]
}
"
f
"SGD with momentum
{
self
.
trainer_cfg
[
'sgd_momentum'
]
}
and "
f
"SGD with momentum
{
self
.
trainer_cfg
[
'sgd_momentum'
]
}
and "
f
"nesterov
{
self
.
trainer_cfg
[
'sgd_nesterov'
]
}
"
)
f
"nesterov
{
self
.
trainer_cfg
[
'sgd_nesterov'
]
}
"
wd_groups
=
get_params_no_wd_on_norm
(
self
,
weight_decay
=
self
.
trainer_cfg
[
'weight_decay'
])
)
wd_groups
=
get_params_no_wd_on_norm
(
self
,
weight_decay
=
self
.
trainer_cfg
[
"weight_decay"
]
)
optimizer
=
torch
.
optim
.
SGD
(
optimizer
=
torch
.
optim
.
SGD
(
wd_groups
,
wd_groups
,
self
.
trainer_cfg
[
"initial_lr"
],
self
.
trainer_cfg
[
"initial_lr"
],
...
@@ -315,19 +322,22 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -315,19 +322,22 @@ class RetinaUNetModule(LightningBaseModuleSWA):
)
)
# configure lr scheduler
# configure lr scheduler
num_iterations
=
self
.
trainer_cfg
[
"max_num_epochs"
]
*
\
num_iterations
=
(
self
.
trainer_cfg
[
"num_train_batches_per_epoch"
]
self
.
trainer_cfg
[
"max_num_epochs"
]
*
self
.
trainer_cfg
[
"num_train_batches_per_epoch"
]
)
scheduler
=
LinearWarmupPolyLR
(
scheduler
=
LinearWarmupPolyLR
(
optimizer
=
optimizer
,
optimizer
=
optimizer
,
warm_iterations
=
self
.
trainer_cfg
[
"warm_iterations"
],
warm_iterations
=
self
.
trainer_cfg
[
"warm_iterations"
],
warm_lr
=
self
.
trainer_cfg
[
"warm_lr"
],
warm_lr
=
self
.
trainer_cfg
[
"warm_lr"
],
poly_gamma
=
self
.
trainer_cfg
[
"poly_gamma"
],
poly_gamma
=
self
.
trainer_cfg
[
"poly_gamma"
],
num_iterations
=
num_iterations
num_iterations
=
num_iterations
,
)
)
return
[
optimizer
],
{
'
scheduler
'
:
scheduler
,
'
interval
'
:
'
step
'
}
return
[
optimizer
],
{
"
scheduler
"
:
scheduler
,
"
interval
"
:
"
step
"
}
@
classmethod
@
classmethod
def
from_config_plan
(
cls
,
def
from_config_plan
(
cls
,
model_cfg
:
dict
,
model_cfg
:
dict
,
plan_arch
:
dict
,
plan_arch
:
dict
,
plan_anchors
:
dict
,
plan_anchors
:
dict
,
...
@@ -359,21 +369,32 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -359,21 +369,32 @@ class RetinaUNetModule(LightningBaseModuleSWA):
will be performed
will be performed
**kwargs:
**kwargs:
"""
"""
logger
.
info
(
f
"Architecture overwrites:
{
model_cfg
[
'plan_arch_overwrites'
]
}
"
logger
.
info
(
f
"Anchor overwrites:
{
model_cfg
[
'plan_anchors_overwrites'
]
}
"
)
f
"Architecture overwrites:
{
model_cfg
[
'plan_arch_overwrites'
]
}
"
logger
.
info
(
f
"Building architecture according to plan of
{
plan_arch
.
get
(
'arch_name'
,
'not_found'
)
}
"
)
f
"Anchor overwrites:
{
model_cfg
[
'plan_anchors_overwrites'
]
}
"
)
logger
.
info
(
f
"Building architecture according to plan of
{
plan_arch
.
get
(
'arch_name'
,
'not_found'
)
}
"
)
plan_arch
.
update
(
model_cfg
[
"plan_arch_overwrites"
])
plan_arch
.
update
(
model_cfg
[
"plan_arch_overwrites"
])
plan_anchors
.
update
(
model_cfg
[
"plan_anchors_overwrites"
])
plan_anchors
.
update
(
model_cfg
[
"plan_anchors_overwrites"
])
logger
.
info
(
f
"Start channels:
{
plan_arch
[
'start_channels'
]
}
; "
logger
.
info
(
f
"Start channels:
{
plan_arch
[
'start_channels'
]
}
; "
f
"head channels:
{
plan_arch
[
'head_channels'
]
}
; "
f
"head channels:
{
plan_arch
[
'head_channels'
]
}
; "
f
"fpn channels:
{
plan_arch
[
'fpn_channels'
]
}
"
)
f
"fpn channels:
{
plan_arch
[
'fpn_channels'
]
}
"
)
_plan_anchors
=
copy
.
deepcopy
(
plan_anchors
)
_plan_anchors
=
copy
.
deepcopy
(
plan_anchors
)
coder
=
BoxCoderND
(
weights
=
(
1.
,)
*
(
plan_arch
[
"dim"
]
*
2
))
coder
=
BoxCoderND
(
weights
=
(
1.0
,)
*
(
plan_arch
[
"dim"
]
*
2
))
s_param
=
False
if
(
"aspect_ratios"
in
_plan_anchors
)
and
\
s_param
=
(
(
_plan_anchors
[
"aspect_ratios"
]
is
not
None
)
else
True
False
anchor_generator
=
get_anchor_generator
(
if
(
"aspect_ratios"
in
_plan_anchors
)
plan_arch
[
"dim"
],
s_param
=
s_param
)(
**
_plan_anchors
)
and
(
_plan_anchors
[
"aspect_ratios"
]
is
not
None
)
else
True
)
anchor_generator
=
get_anchor_generator
(
plan_arch
[
"dim"
],
s_param
=
s_param
)(
**
_plan_anchors
)
encoder
=
cls
.
_build_encoder
(
encoder
=
cls
.
_build_encoder
(
plan_arch
=
plan_arch
,
plan_arch
=
plan_arch
,
...
@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
model_cfg
=
model_cfg
,
model_cfg
=
model_cfg
,
classifier
=
classifier
,
classifier
=
classifier
,
regressor
=
regressor
,
regressor
=
regressor
,
coder
=
coder
coder
=
coder
,
)
)
segmenter
=
cls
.
_build_segmenter
(
segmenter
=
cls
.
_build_segmenter
(
plan_arch
=
plan_arch
,
plan_arch
=
plan_arch
,
...
@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA):
remove_small_boxes
=
plan_arch
.
get
(
"remove_small_boxes"
,
0.01
)
remove_small_boxes
=
plan_arch
.
get
(
"remove_small_boxes"
,
0.01
)
nms_thresh
=
plan_arch
.
get
(
"nms_thresh"
,
0.6
)
nms_thresh
=
plan_arch
.
get
(
"nms_thresh"
,
0.6
)
logger
.
info
(
f
"Model Inference Summary:
\n
"
#
logger.info(f"Model Inference Summary: \n"
f
"detections_per_img:
{
detections_per_img
}
\n
"
#
f"detections_per_img: {detections_per_img} \n"
f
"score_thresh:
{
score_thresh
}
\n
"
#
f"score_thresh: {score_thresh} \n"
f
"topk_candidates:
{
topk_candidates
}
\n
"
#
f"topk_candidates: {topk_candidates} \n"
f
"remove_small_boxes:
{
remove_small_boxes
}
\n
"
#
f"remove_small_boxes: {remove_small_boxes} \n"
f
"nms_thresh:
{
nms_thresh
}
"
,
#
f"nms_thresh: {nms_thresh}",
)
#
)
return
BaseRetinaNet
(
return
BaseRetinaNet
(
dim
=
plan_arch
[
"dim"
],
dim
=
plan_arch
[
"dim"
],
...
@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
EncoderType: encoder instance
EncoderType: encoder instance
"""
"""
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: encoder
{
cls
.
encoder_cls
.
__name__
}
:
{
model_cfg
[
'encoder_kwargs'
]
}
"
)
logger
.
info
(
f
"Building:: encoder
{
cls
.
encoder_cls
.
__name__
}
:
{
model_cfg
[
'encoder_kwargs'
]
}
"
)
encoder
=
cls
.
encoder_cls
(
encoder
=
cls
.
encoder_cls
(
conv
=
conv
,
conv
=
conv
,
conv_kernels
=
plan_arch
[
"conv_kernels"
],
conv_kernels
=
plan_arch
[
"conv_kernels"
],
...
@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
start_channels
=
plan_arch
[
"start_channels"
],
start_channels
=
plan_arch
[
"start_channels"
],
stage_kwargs
=
None
,
stage_kwargs
=
None
,
max_channels
=
plan_arch
.
get
(
"max_channels"
,
320
),
max_channels
=
plan_arch
.
get
(
"max_channels"
,
320
),
**
model_cfg
[
'
encoder_kwargs
'
],
**
model_cfg
[
"
encoder_kwargs
"
],
)
)
return
encoder
return
encoder
...
@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
DecoderType: decoder instance
DecoderType: decoder instance
"""
"""
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: decoder
{
cls
.
decoder_cls
.
__name__
}
:
{
model_cfg
[
'decoder_kwargs'
]
}
"
)
logger
.
info
(
f
"Building:: decoder
{
cls
.
decoder_cls
.
__name__
}
:
{
model_cfg
[
'decoder_kwargs'
]
}
"
)
decoder
=
cls
.
decoder_cls
(
decoder
=
cls
.
decoder_cls
(
conv
=
conv
,
conv
=
conv
,
conv_kernels
=
plan_arch
[
"conv_kernels"
],
conv_kernels
=
plan_arch
[
"conv_kernels"
],
...
@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
in_channels
=
encoder
.
get_channels
(),
in_channels
=
encoder
.
get_channels
(),
decoder_levels
=
plan_arch
[
"decoder_levels"
],
decoder_levels
=
plan_arch
[
"decoder_levels"
],
fixed_out_channels
=
plan_arch
[
"fpn_channels"
],
fixed_out_channels
=
plan_arch
[
"fpn_channels"
],
**
model_cfg
[
'
decoder_kwargs
'
],
**
model_cfg
[
"
decoder_kwargs
"
],
)
)
return
decoder
return
decoder
...
@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
"""
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
name
=
cls
.
head_classifier_cls
.
__name__
name
=
cls
.
head_classifier_cls
.
__name__
kwargs
=
model_cfg
[
'
head_classifier_kwargs
'
]
kwargs
=
model_cfg
[
"
head_classifier_kwargs
"
]
logger
.
info
(
f
"Building:: classifier
{
name
}
:
{
kwargs
}
"
)
logger
.
info
(
f
"Building:: classifier
{
name
}
:
{
kwargs
}
"
)
classifier
=
cls
.
head_classifier_cls
(
classifier
=
cls
.
head_classifier_cls
(
...
@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
"""
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
name
=
cls
.
head_regressor_cls
.
__name__
name
=
cls
.
head_regressor_cls
.
__name__
kwargs
=
model_cfg
[
'
head_regressor_kwargs
'
]
kwargs
=
model_cfg
[
"
head_regressor_kwargs
"
]
logger
.
info
(
f
"Building:: regressor
{
name
}
:
{
kwargs
}
"
)
logger
.
info
(
f
"Building:: regressor
{
name
}
:
{
kwargs
}
"
)
regressor
=
cls
.
head_regressor_cls
(
regressor
=
cls
.
head_regressor_cls
(
...
@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA):
HeadType: instantiated head
HeadType: instantiated head
"""
"""
head_name
=
cls
.
head_cls
.
__name__
head_name
=
cls
.
head_cls
.
__name__
head_kwargs
=
model_cfg
[
'
head_kwargs
'
]
head_kwargs
=
model_cfg
[
"
head_kwargs
"
]
sampler_name
=
cls
.
head_sampler_cls
.
__name__
sampler_name
=
cls
.
head_sampler_cls
.
__name__
sampler_kwargs
=
model_cfg
[
'
head_sampler_kwargs
'
]
sampler_kwargs
=
model_cfg
[
"
head_sampler_kwargs
"
]
logger
.
info
(
f
"Building:: head
{
head_name
}
:
{
head_kwargs
}
"
logger
.
info
(
f
"sampler
{
sampler_name
}
:
{
sampler_kwargs
}
"
)
f
"Building:: head
{
head_name
}
:
{
head_kwargs
}
"
f
"sampler
{
sampler_name
}
:
{
sampler_kwargs
}
"
)
sampler
=
cls
.
head_sampler_cls
(
**
sampler_kwargs
)
sampler
=
cls
.
head_sampler_cls
(
**
sampler_kwargs
)
head
=
cls
.
head_cls
(
head
=
cls
.
head_cls
(
classifier
=
classifier
,
classifier
=
classifier
,
...
@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
"""
if
cls
.
segmenter_cls
is
not
None
:
if
cls
.
segmenter_cls
is
not
None
:
name
=
cls
.
segmenter_cls
.
__name__
name
=
cls
.
segmenter_cls
.
__name__
kwargs
=
model_cfg
[
'
segmenter_kwargs
'
]
kwargs
=
model_cfg
[
"
segmenter_kwargs
"
]
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: segmenter
{
name
}
{
kwargs
}
"
)
logger
.
info
(
f
"Building:: segmenter
{
name
}
{
kwargs
}
"
)
...
@@ -661,14 +688,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -661,14 +688,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
3
:
{
3
:
{
"boxes"
:
BoxEnsemblerSelective
,
"boxes"
:
BoxEnsemblerSelective
,
"seg"
:
SegmentationEnsembler
,
"seg"
:
SegmentationEnsembler
,
}
}
,
}
}
if
dim
==
2
:
if
dim
==
2
:
raise
NotImplementedError
raise
NotImplementedError
return
_lookup
[
dim
][
key
]
return
_lookup
[
dim
][
key
]
@
classmethod
@
classmethod
def
get_predictor
(
cls
,
def
get_predictor
(
cls
,
plan
:
Dict
,
plan
:
Dict
,
models
:
Sequence
[
RetinaUNetModule
],
models
:
Sequence
[
RetinaUNetModule
],
num_tta_transforms
:
int
=
None
,
num_tta_transforms
:
int
=
None
,
...
@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA):
num_tta_transforms
=
8
if
plan
[
"network_dim"
]
==
3
else
4
num_tta_transforms
=
8
if
plan
[
"network_dim"
]
==
3
else
4
# setup
# setup
tta_transforms
,
tta_inverse_transforms
=
\
tta_transforms
,
tta_inverse_transforms
=
get_tta_transforms
(
get_tta_transforms
(
num_tta_transforms
,
True
)
num_tta_transforms
,
True
logger
.
info
(
f
"Using
{
len
(
tta_transforms
)
}
tta transformations for prediction (one dummy trafo)."
)
)
logger
.
info
(
f
"Using
{
len
(
tta_transforms
)
}
tta transformations for prediction (one dummy trafo)."
)
ensembler
=
{
"boxes"
:
partial
(
ensembler
=
{
"boxes"
:
partial
(
cls
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
cls
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
parameters
=
inferene_plan
,
parameters
=
inferene_plan
,
)}
)
}
if
do_seg
:
if
do_seg
:
ensembler
[
"seg"
]
=
partial
(
ensembler
[
"seg"
]
=
partial
(
cls
.
get_ensembler_cls
(
key
=
"seg"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
cls
.
get_ensembler_cls
(
key
=
"seg"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
...
@@ -711,7 +744,8 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -711,7 +744,8 @@ class RetinaUNetModule(LightningBaseModuleSWA):
predictor
.
pre_transform
=
Inference2D
([
"data"
])
predictor
.
pre_transform
=
Inference2D
([
"data"
])
return
predictor
return
predictor
def
sweep
(
self
,
def
sweep
(
self
,
cfg
:
dict
,
cfg
:
dict
,
save_dir
:
os
.
PathLike
,
save_dir
:
os
.
PathLike
,
train_data_dir
:
os
.
PathLike
,
train_data_dir
:
os
.
PathLike
,
...
@@ -767,7 +801,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
...
@@ -767,7 +801,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
)
)
logger
.
info
(
"Start parameter sweep..."
)
logger
.
info
(
"Start parameter sweep..."
)
ensembler_cls
=
self
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
self
.
plan
[
"network_dim"
])
ensembler_cls
=
self
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
self
.
plan
[
"network_dim"
]
)
sweeper
=
BoxSweeper
(
sweeper
=
BoxSweeper
(
classes
=
[
item
for
_
,
item
in
cfg
[
"data"
][
"labels"
].
items
()],
classes
=
[
item
for
_
,
item
in
cfg
[
"data"
][
"labels"
].
items
()],
pred_dir
=
prediction_dir
,
pred_dir
=
prediction_dir
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment