Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
ede95851
Commit
ede95851
authored
Apr 22, 2021
by
mibaumgartner
Browse files
ptmodule
parent
4f533dd8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1020 additions
and
0 deletions
+1020
-0
nndet/ptmodule/__init__.py
nndet/ptmodule/__init__.py
+6
-0
nndet/ptmodule/base_module.py
nndet/ptmodule/base_module.py
+201
-0
nndet/ptmodule/retinaunet/__init__.py
nndet/ptmodule/retinaunet/__init__.py
+3
-0
nndet/ptmodule/retinaunet/base.py
nndet/ptmodule/retinaunet/base.py
+771
-0
nndet/ptmodule/retinaunet/v001.py
nndet/ptmodule/retinaunet/v001.py
+39
-0
No files found.
nndet/ptmodule/__init__.py
0 → 100644
View file @
ede95851
from
typing
import
Mapping
,
Type
from
nndet.utils.registry
import
Registry
from
nndet.ptmodule.base_module
import
LightningBaseModule
MODULE_REGISTRY
:
Mapping
[
str
,
Type
[
LightningBaseModule
]]
=
Registry
()
from
nndet.ptmodule.retinaunet
import
*
nndet/ptmodule/base_module.py
0 → 100644
View file @
ede95851
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from
__future__
import
annotations
import
os
from
time
import
time
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Sequence
,
Hashable
,
Type
,
TypeVar
import
torch
import
pytorch_lightning
as
pl
from
pytorch_lightning.core.memory
import
ModelSummary
from
loguru
import
logger
from
nndet.io.load
import
save_txt
from
nndet.inference.predictor
import
Predictor
class
LightningBaseModule
(
pl
.
LightningModule
):
def
__init__
(
self
,
model_cfg
:
dict
,
trainer_cfg
:
dict
,
plan
:
dict
,
**
kwargs
):
"""
Provides a base module which is used inside of nnDetection.
All lightning modules of nnDetection should be derifed from this!
Args:
model_cfg: model configuration. Check :method:`from_config_plan`
for more information
trainer_cfg: trainer information
plan: contains parameters which were derived from the planning
stage
"""
super
().
__init__
()
self
.
model_cfg
=
model_cfg
self
.
trainer_cfg
=
trainer_cfg
self
.
plan
=
plan
self
.
model
=
self
.
from_config_plan
(
model_cfg
=
self
.
model_cfg
,
plan_arch
=
self
.
plan
[
"architecture"
],
plan_anchors
=
self
.
plan
[
"anchors"
],
)
self
.
example_input_array_shape
=
(
1
,
plan
[
"architecture"
][
"in_channels"
],
*
plan
[
"patch_size"
],
)
self
.
epoch_start_tic
=
0
self
.
epoch_end_toc
=
0
@
property
def
max_epochs
(
self
):
"""
Number of epochs to train
"""
return
self
.
trainer_cfg
[
"max_num_epochs"
]
def
on_epoch_start
(
self
)
->
None
:
"""
Save time
"""
self
.
epoch_start_tic
=
time
()
return
super
().
on_epoch_start
()
def
validation_epoch_end
(
self
,
validation_step_outputs
):
"""
Print time of epoch
(needed for cluster where progress bar is deactivated)
"""
self
.
epoch_end_toc
=
time
()
logger
.
info
(
f
"This epoch took
{
int
(
self
.
epoch_end_toc
-
self
.
epoch_start_tic
)
}
s"
)
return
super
().
validation_epoch_end
(
validation_step_outputs
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Used to generate summary
Do not(!) use this for inference. This will only forward
the input through the network which does not include
detection spcific postprocessing!
"""
return
self
.
model
(
x
)
@
property
def
example_input_array
(
self
):
"""
Create example input
"""
return
torch
.
zeros
(
*
self
.
example_input_array_shape
)
def
summarize
(
self
,
mode
:
Optional
[
str
])
->
Optional
[
ModelSummary
]:
"""
Save model summary as txt
"""
summary
=
super
().
summarize
(
mode
=
mode
)
save_txt
(
summary
,
"./network"
)
return
summary
def
inference_step
(
self
,
batch
:
Any
,
**
kwargs
)
->
Dict
[
str
,
Any
]:
"""
Prediction method used by nnDetection predictor class
"""
return
self
.
model
.
inference_step
(
batch
,
**
kwargs
)
@
classmethod
def
from_config_plan
(
cls
,
model_cfg
:
dict
,
plan_arch
:
dict
,
plan_anchors
:
dict
,
log_num_anchors
:
str
=
None
,
**
kwargs
,
):
"""
Used to generate the model
"""
raise
NotImplementedError
@
staticmethod
def
get_ensembler_cls
(
key
:
Hashable
,
dim
:
int
)
->
Callable
:
"""
Get ensembler classes to combine multiple predictions
Needs to be overwritten in subclasses!
"""
raise
NotImplementedError
@
classmethod
def
get_predictor
(
cls
,
plan
:
Dict
,
models
:
Sequence
[
LightningBaseModule
],
num_tta_transforms
:
int
=
None
,
**
kwargs
)
->
Type
[
Predictor
]:
"""
Get predictor
Needs to be overwritten in subclasses!
"""
raise
NotImplementedError
def
sweep
(
self
,
cfg
:
dict
,
save_dir
:
os
.
PathLike
,
train_data_dir
:
os
.
PathLike
,
case_ids
:
Sequence
[
str
],
run_prediction
:
bool
=
True
,
)
->
Dict
[
str
,
Any
]:
"""
Sweep parameters to find the best predictions
Needs to be overwritten in subclasses!
Args:
cfg: config used for training
save_dir: save dir used for training
train_data_dir: directory where preprocessed training/validation
data is located
case_ids: case identifies to prepare and predict
run_prediction: predict cases
**kwargs: keyword arguments passed to predict function
"""
raise
NotImplementedError
class
LightningBaseModuleSWA
(
LightningBaseModule
):
@
property
def
max_epochs
(
self
):
"""
Number of epochs to train
"""
return
self
.
trainer_cfg
[
"max_num_epochs"
]
+
self
.
trainer_cfg
[
"swa_epochs"
]
def
configure_callbacks
(
self
):
from
nndet.training.swa
import
SWACycleLinear
callbacks
=
[]
callbacks
.
append
(
SWACycleLinear
(
swa_epoch_start
=
self
.
trainer_cfg
[
"max_num_epochs"
],
cycle_initial_lr
=
self
.
trainer_cfg
[
"initial_lr"
]
/
10.
,
cycle_final_lr
=
self
.
trainer_cfg
[
"initial_lr"
]
/
1000.
,
num_iterations_per_epoch
=
self
.
trainer_cfg
[
"num_train_batches_per_epoch"
],
)
)
return
callbacks
LightningBaseModuleType
=
TypeVar
(
'LightningBaseModuleType'
,
bound
=
LightningBaseModule
)
nndet/ptmodule/retinaunet/__init__.py
0 → 100644
View file @
ede95851
from
nndet.ptmodule.retinaunet.base
import
RetinaUNetModule
from
nndet.ptmodule.retinaunet.v001
import
RetinaUNetV001
from
nndet.ptmodule.retinaunet.c010
import
RetinaUNetC010
nndet/ptmodule/retinaunet/base.py
0 → 100644
View file @
ede95851
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from
__future__
import
annotations
import
os
import
copy
from
collections
import
defaultdict
from
pathlib
import
Path
from
functools
import
partial
from
typing
import
Callable
,
Hashable
,
Sequence
,
Dict
,
Any
,
Type
import
torch
import
numpy
as
np
from
loguru
import
logger
from
torchvision.models.detection.rpn
import
AnchorType
,
AnchorGenerator
from
nndet.utils.tensor
import
to_numpy
from
nndet.evaluator.det
import
BoxEvaluator
from
nndet.evaluator.seg
import
SegmentationEvaluator
from
nndet.detection.retina
import
BaseRetinaNet
from
nndet.detection.boxes.matcher
import
IoUMatcher
from
nndet.detection.boxes.sampler
import
HardNegativeSamplerBatched
from
nndet.detection.boxes.coder
import
CoderType
,
BoxCoderND
from
nndet.detection.boxes.anchors
import
get_anchor_generator
from
nndet.detection.boxes.utils
import
box_iou
from
nndet.ptmodule.base_module
import
LightningBaseModuleSWA
,
LightningBaseModule
from
nndet.models.conv
import
Generator
,
ConvInstanceRelu
,
ConvGroupRelu
from
nndet.models.blocks.basic
import
StackedConvBlock2
from
nndet.models.encoder.modular
import
EncoderType
,
Encoder
from
nndet.models.decoder.base
import
DecoderType
,
BaseUFPN
,
UFPNModular
from
nndet.models.heads.classifier
import
ClassifierType
,
CEClassifier
from
nndet.models.heads.regressor
import
RegressorType
,
L1Regressor
from
nndet.models.heads.comb
import
HeadType
,
DetectionHeadHNM
from
nndet.models.heads.segmenter
import
SegmenterType
,
DiCESegmenter
from
nndet.training.optimizer
import
get_params_no_wd_on_norm
from
nndet.training.learning_rate
import
LinearWarmupPolyLR
from
nndet.inference.predictor
import
Predictor
from
nndet.inference.sweeper
import
BoxSweeper
from
nndet.inference.transforms
import
get_tta_transforms
,
Inference2D
from
nndet.inference.loading
import
load_final_model
from
nndet.inference.helper
import
predict_dir
from
nndet.inference.ensembler.segmentation
import
SegmentationEnsembler
from
nndet.inference.ensembler.detection
import
BoxEnsemblerSelective
,
BoxEnsemblerSelective2D
from
rising.transforms
import
Compose
from
nndet.io.transforms
import
Instances2Boxes
,
Instances2Segmentation
,
FindInstances
class
RetinaUNetModule
(
LightningBaseModuleSWA
):
base_conv_cls
=
ConvInstanceRelu
head_conv_cls
=
ConvGroupRelu
block
=
StackedConvBlock2
encoder_cls
=
Encoder
decoder_cls
=
UFPNModular
matcher_cls
=
IoUMatcher
head_cls
=
DetectionHeadHNM
head_classifier_cls
=
CEClassifier
head_regressor_cls
=
L1Regressor
head_sampler_cls
=
HardNegativeSamplerBatched
segmenter_cls
=
DiCESegmenter
def
__init__
(
self
,
model_cfg
:
dict
,
trainer_cfg
:
dict
,
plan
:
dict
,
**
kwargs
):
"""
RetinaUNet Lightning Module Skeleton
Args:
model_cfg: model configuration. Check :method:`from_config_plan`
for more information
trainer_cfg: trainer information
plan: contains parameters which were derived from the planning
stage
"""
super
().
__init__
(
model_cfg
=
model_cfg
,
trainer_cfg
=
trainer_cfg
,
plan
=
plan
,
)
_classes
=
[
f
"class
{
c
}
"
for
c
in
range
(
plan
[
"architecture"
][
"classifier_classes"
])]
self
.
box_evaluator
=
BoxEvaluator
.
create
(
classes
=
_classes
,
fast
=
True
,
save_dir
=
None
,
)
self
.
seg_evaluator
=
SegmentationEvaluator
.
create
()
self
.
pre_trafo
=
Compose
(
FindInstances
(
instance_key
=
"target"
,
save_key
=
"present_instances"
,
),
Instances2Boxes
(
instance_key
=
"target"
,
map_key
=
"instance_mapping"
,
box_key
=
"boxes"
,
class_key
=
"classes"
,
present_instances
=
"present_instances"
,
),
Instances2Segmentation
(
instance_key
=
"target"
,
map_key
=
"instance_mapping"
,
present_instances
=
"present_instances"
,
)
)
self
.
eval_score_key
=
"mAP_IoU_0.10_0.50_0.05_MaxDet_100"
def
training_step
(
self
,
batch
,
batch_idx
):
"""
Computes a single training step
See :class:`BaseRetinaNet` for more information
"""
with
torch
.
no_grad
():
batch
=
self
.
pre_trafo
(
**
batch
)
losses
,
_
=
self
.
model
.
train_step
(
images
=
batch
[
"data"
],
targets
=
{
"target_boxes"
:
batch
[
"boxes"
],
"target_classes"
:
batch
[
"classes"
],
"target_seg"
:
batch
[
'target'
][:,
0
]
# Remove channel dimension
},
evaluation
=
False
,
batch_num
=
batch_idx
,
)
loss
=
sum
(
losses
.
values
())
return
{
"loss"
:
loss
,
**
{
key
:
l
.
detach
().
item
()
for
key
,
l
in
losses
.
items
()}}
def
validation_step
(
self
,
batch
,
batch_idx
):
"""
Computes a single validation step (same as train step but with
additional prediciton processing)
See :class:`BaseRetinaNet` for more information
"""
with
torch
.
no_grad
():
batch
=
self
.
pre_trafo
(
**
batch
)
targets
=
{
"target_boxes"
:
batch
[
"boxes"
],
"target_classes"
:
batch
[
"classes"
],
"target_seg"
:
batch
[
'target'
][:,
0
]
# Remove channel dimension
}
losses
,
prediction
=
self
.
model
.
train_step
(
images
=
batch
[
"data"
],
targets
=
targets
,
evaluation
=
True
,
batch_num
=
batch_idx
,
)
loss
=
sum
(
losses
.
values
())
self
.
evaluation_step
(
prediction
=
prediction
,
targets
=
targets
)
return
{
"loss"
:
loss
.
detach
().
item
(),
**
{
key
:
l
.
detach
().
item
()
for
key
,
l
in
losses
.
items
()}}
def
evaluation_step
(
self
,
prediction
:
dict
,
targets
:
dict
,
):
"""
Perform an evaluation step to add predictions and gt to
caching mechanism which is evaluated at the end of the epoch
Args:
prediction: predictions obtained from model
'pred_boxes': List[Tensor]: predicted bounding boxes for
each image List[[R, dim * 2]]
'pred_scores': List[Tensor]: predicted probability for
the class List[[R]]
'pred_labels': List[Tensor]: predicted class List[[R]]
'pred_seg': Tensor: predicted segmentation [N, dims]
targets: ground truth
`target_boxes` (List[Tensor]): ground truth bounding boxes
(x1, y1, x2, y2, (z1, z2))[X, dim * 2], X= number of ground
truth boxes in image
`target_classes` (List[Tensor]): ground truth class per box
(classes start from 0) [X], X= number of ground truth
boxes in image
`target_seg` (Tensor): segmentation ground truth (if seg was
found in input dict)
"""
pred_boxes
=
to_numpy
(
prediction
[
"pred_boxes"
])
pred_classes
=
to_numpy
(
prediction
[
"pred_labels"
])
pred_scores
=
to_numpy
(
prediction
[
"pred_scores"
])
gt_boxes
=
to_numpy
(
targets
[
"target_boxes"
])
gt_classes
=
to_numpy
(
targets
[
"target_classes"
])
gt_ignore
=
None
self
.
box_evaluator
.
run_online_evaluation
(
pred_boxes
=
pred_boxes
,
pred_classes
=
pred_classes
,
pred_scores
=
pred_scores
,
gt_boxes
=
gt_boxes
,
gt_classes
=
gt_classes
,
gt_ignore
=
gt_ignore
,
)
pred_seg
=
to_numpy
(
prediction
[
"pred_seg"
])
gt_seg
=
to_numpy
(
targets
[
"target_seg"
])
self
.
seg_evaluator
.
run_online_evaluation
(
seg_probs
=
pred_seg
,
target
=
gt_seg
,
)
def
training_epoch_end
(
self
,
training_step_outputs
):
"""
Log train loss to loguru logger
"""
# process and log losses
vals
=
defaultdict
(
list
)
for
_val
in
training_step_outputs
:
for
_k
,
_v
in
_val
.
items
():
if
_k
==
"loss"
:
vals
[
_k
].
append
(
_v
.
detach
().
item
())
else
:
vals
[
_k
].
append
(
_v
)
for
_key
,
_vals
in
vals
.
items
():
mean_val
=
np
.
mean
(
_vals
)
if
_key
==
"loss"
:
logger
.
info
(
f
"Train loss reached:
{
mean_val
:
0.5
f
}
"
)
self
.
log
(
f
"train_
{
_key
}
"
,
mean_val
,
sync_dist
=
True
)
return
super
().
training_epoch_end
(
training_step_outputs
)
def
validation_epoch_end
(
self
,
validation_step_outputs
):
"""
Log val loss to loguru logger
"""
# process and log losses
vals
=
defaultdict
(
list
)
for
_val
in
validation_step_outputs
:
for
_k
,
_v
in
_val
.
items
():
vals
[
_k
].
append
(
_v
)
for
_key
,
_vals
in
vals
.
items
():
mean_val
=
np
.
mean
(
_vals
)
if
_key
==
"loss"
:
logger
.
info
(
f
"Val loss reached:
{
mean_val
:
0.5
f
}
"
)
self
.
log
(
f
"val_
{
_key
}
"
,
mean_val
,
sync_dist
=
True
)
# process and log metrics
self
.
evaluation_end
()
return
super
().
validation_epoch_end
(
validation_step_outputs
)
def
evaluation_end
(
self
):
"""
Uses the cached values from `evaluation_step` to perform the evaluation
of the epoch
"""
metric_scores
,
_
=
self
.
box_evaluator
.
finish_online_evaluation
()
self
.
box_evaluator
.
reset
()
logger
.
info
(
f
"mAP@0.1:0.5:0.05:
{
metric_scores
[
'mAP_IoU_0.10_0.50_0.05_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.1:
{
metric_scores
[
'AP_IoU_0.10_MaxDet_100'
]:
0.3
f
}
"
f
"AP@0.5:
{
metric_scores
[
'AP_IoU_0.50_MaxDet_100'
]:
0.3
f
}
"
)
seg_scores
,
_
=
self
.
seg_evaluator
.
finish_online_evaluation
()
self
.
seg_evaluator
.
reset
()
metric_scores
.
update
(
seg_scores
)
logger
.
info
(
f
"Proxy FG Dice:
{
seg_scores
[
'seg_dice'
]:
0.3
f
}
"
)
for
key
,
item
in
metric_scores
.
items
():
self
.
log
(
f
'
{
key
}
'
,
item
,
on_step
=
None
,
on_epoch
=
True
,
prog_bar
=
False
,
logger
=
True
)
def
configure_optimizers
(
self
):
"""
Configure optimizer and scheduler
Base configuration is SGD with LinearWarmup and PolyLR learning rate
schedule
"""
# configure optimizer
logger
.
info
(
f
"Running: initial_lr
{
self
.
trainer_cfg
[
'initial_lr'
]
}
"
f
"weight_decay
{
self
.
trainer_cfg
[
'weight_decay'
]
}
"
f
"SGD with momentum
{
self
.
trainer_cfg
[
'sgd_momentum'
]
}
and "
f
"nesterov
{
self
.
trainer_cfg
[
'sgd_nesterov'
]
}
"
)
wd_groups
=
get_params_no_wd_on_norm
(
self
,
weight_decay
=
self
.
trainer_cfg
[
'weight_decay'
])
optimizer
=
torch
.
optim
.
SGD
(
wd_groups
,
self
.
trainer_cfg
[
"initial_lr"
],
weight_decay
=
self
.
trainer_cfg
[
"weight_decay"
],
momentum
=
self
.
trainer_cfg
[
"sgd_momentum"
],
nesterov
=
self
.
trainer_cfg
[
"sgd_nesterov"
],
)
# configure lr scheduler
num_iterations
=
self
.
trainer_cfg
[
"max_num_epochs"
]
*
\
self
.
trainer_cfg
[
"num_train_batches_per_epoch"
]
scheduler
=
LinearWarmupPolyLR
(
optimizer
=
optimizer
,
warm_iterations
=
self
.
trainer_cfg
[
"warm_iterations"
],
warm_lr
=
self
.
trainer_cfg
[
"warm_lr"
],
poly_gamma
=
self
.
trainer_cfg
[
"poly_gamma"
],
num_iterations
=
num_iterations
)
return
[
optimizer
],
{
'scheduler'
:
scheduler
,
'interval'
:
'step'
}
@
classmethod
def
from_config_plan
(
cls
,
model_cfg
:
dict
,
plan_arch
:
dict
,
plan_anchors
:
dict
,
log_num_anchors
:
str
=
None
,
**
kwargs
,
):
"""
Create Configurable RetinaUNet
Args:
model_cfg: model configurations
See example configs for more info
plan_arch: plan architecture
`dim` (int): number of spatial dimensions
`in_channels` (int): number of input channels
`classifier_classes` (int): number of classes
`seg_classes` (int): number of classes
`start_channels` (int): number of start channels in encoder
`fpn_channels` (int): number of channels to use for FPN
`head_channels` (int): number of channels to use for head
`decoder_levels` (int): decoder levels to user for detection
plan_anchors: parameters for anchors (see
:class:`AnchorGenerator` for more info)
`stride`: stride
`aspect_ratios`: aspect ratios
`sizes`: sized for 2d acnhors
(`zsizes`: additional z sizes for 3d)
log_num_anchors: name of logger to use; if None, no logging
will be performed
**kwargs:
"""
logger
.
info
(
f
"Architecture overwrites:
{
model_cfg
[
'plan_arch_overwrites'
]
}
"
f
"Anchor overwrites:
{
model_cfg
[
'plan_anchors_overwrites'
]
}
"
)
logger
.
info
(
f
"Building architecture according to plan of
{
plan_arch
.
get
(
'arch_name'
,
'not_found'
)
}
"
)
plan_arch
.
update
(
model_cfg
[
"plan_arch_overwrites"
])
plan_anchors
.
update
(
model_cfg
[
"plan_anchors_overwrites"
])
logger
.
info
(
f
"Start channels:
{
plan_arch
[
'start_channels'
]
}
; "
f
"head channels:
{
plan_arch
[
'head_channels'
]
}
; "
f
"fpn channels:
{
plan_arch
[
'fpn_channels'
]
}
"
)
_plan_anchors
=
copy
.
deepcopy
(
plan_anchors
)
coder
=
BoxCoderND
(
weights
=
(
1.
,)
*
(
plan_arch
[
"dim"
]
*
2
))
s_param
=
False
if
(
"aspect_ratios"
in
_plan_anchors
)
and
\
(
_plan_anchors
[
"aspect_ratios"
]
is
not
None
)
else
True
anchor_generator
=
get_anchor_generator
(
plan_arch
[
"dim"
],
s_param
=
s_param
)(
**
_plan_anchors
)
encoder
=
cls
.
_build_encoder
(
plan_arch
=
plan_arch
,
model_cfg
=
model_cfg
,
)
decoder
=
cls
.
_build_decoder
(
encoder
=
encoder
,
plan_arch
=
plan_arch
,
model_cfg
=
model_cfg
,
)
matcher
=
cls
.
matcher_cls
(
similarity_fn
=
box_iou
,
**
model_cfg
[
"matcher_kwargs"
],
)
classifier
=
cls
.
_build_head_classifier
(
plan_arch
=
plan_arch
,
model_cfg
=
model_cfg
,
anchor_generator
=
anchor_generator
,
)
regressor
=
cls
.
_build_head_regressor
(
plan_arch
=
plan_arch
,
model_cfg
=
model_cfg
,
anchor_generator
=
anchor_generator
,
)
head
=
cls
.
_build_head
(
plan_arch
=
plan_arch
,
model_cfg
=
model_cfg
,
classifier
=
classifier
,
regressor
=
regressor
,
coder
=
coder
)
segmenter
=
cls
.
_build_segmenter
(
plan_arch
=
plan_arch
,
model_cfg
=
model_cfg
,
decoder
=
decoder
,
)
detections_per_img
=
plan_arch
.
get
(
"detections_per_img"
,
100
)
score_thresh
=
plan_arch
.
get
(
"score_thresh"
,
0
)
topk_candidates
=
plan_arch
.
get
(
"topk_candidates"
,
10000
)
remove_small_boxes
=
plan_arch
.
get
(
"remove_small_boxes"
,
0.01
)
nms_thresh
=
plan_arch
.
get
(
"nms_thresh"
,
0.6
)
logger
.
info
(
f
"Model Inference Summary:
\n
"
f
"detections_per_img:
{
detections_per_img
}
\n
"
f
"score_thresh:
{
score_thresh
}
\n
"
f
"topk_candidates:
{
topk_candidates
}
\n
"
f
"remove_small_boxes:
{
remove_small_boxes
}
\n
"
f
"nms_thresh:
{
nms_thresh
}
"
,
)
return
BaseRetinaNet
(
dim
=
plan_arch
[
"dim"
],
encoder
=
encoder
,
decoder
=
decoder
,
head
=
head
,
anchor_generator
=
anchor_generator
,
matcher
=
matcher
,
num_classes
=
plan_arch
[
"classifier_classes"
],
decoder_levels
=
plan_arch
[
"decoder_levels"
],
segmenter
=
segmenter
,
# model_max_instances_per_batch_element (in mdt per img, per class; here: per img)
detections_per_img
=
detections_per_img
,
score_thresh
=
score_thresh
,
topk_candidates
=
topk_candidates
,
remove_small_boxes
=
remove_small_boxes
,
nms_thresh
=
nms_thresh
,
)
@
classmethod
def
_build_encoder
(
cls
,
plan_arch
:
dict
,
model_cfg
:
dict
,
)
->
EncoderType
:
"""
Build encoder network
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
Returns:
EncoderType: encoder instance
"""
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: encoder
{
cls
.
encoder_cls
.
__name__
}
:
{
model_cfg
[
'encoder_kwargs'
]
}
"
)
encoder
=
cls
.
encoder_cls
(
conv
=
conv
,
conv_kernels
=
plan_arch
[
"conv_kernels"
],
strides
=
plan_arch
[
"strides"
],
block_cls
=
cls
.
block
,
in_channels
=
plan_arch
[
"in_channels"
],
start_channels
=
plan_arch
[
"start_channels"
],
stage_kwargs
=
None
,
max_channels
=
plan_arch
.
get
(
"max_channels"
,
320
),
**
model_cfg
[
'encoder_kwargs'
],
)
return
encoder
@
classmethod
def
_build_decoder
(
cls
,
plan_arch
:
dict
,
model_cfg
:
dict
,
encoder
:
EncoderType
,
)
->
DecoderType
:
"""
Build decoder network
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
Returns:
DecoderType: decoder instance
"""
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: decoder
{
cls
.
decoder_cls
.
__name__
}
:
{
model_cfg
[
'decoder_kwargs'
]
}
"
)
decoder
=
cls
.
decoder_cls
(
conv
=
conv
,
conv_kernels
=
plan_arch
[
"conv_kernels"
],
strides
=
encoder
.
get_strides
(),
in_channels
=
encoder
.
get_channels
(),
decoder_levels
=
plan_arch
[
"decoder_levels"
],
fixed_out_channels
=
plan_arch
[
"fpn_channels"
],
**
model_cfg
[
'decoder_kwargs'
],
)
return
decoder
@
classmethod
def
_build_head_classifier
(
cls
,
plan_arch
:
dict
,
model_cfg
:
dict
,
anchor_generator
:
AnchorType
,
)
->
ClassifierType
:
"""
Build classification subnetwork for detection head
Args:
anchor_generator: anchor generator instance
plan_arch: architecture settings
model_cfg: additional architecture settings
Returns:
ClassifierType: classification instance
"""
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
name
=
cls
.
head_classifier_cls
.
__name__
kwargs
=
model_cfg
[
'head_classifier_kwargs'
]
logger
.
info
(
f
"Building:: classifier
{
name
}
:
{
kwargs
}
"
)
classifier
=
cls
.
head_classifier_cls
(
conv
=
conv
,
in_channels
=
plan_arch
[
"fpn_channels"
],
internal_channels
=
plan_arch
[
"head_channels"
],
num_classes
=
plan_arch
[
"classifier_classes"
],
anchors_per_pos
=
anchor_generator
.
num_anchors_per_location
()[
0
],
num_levels
=
len
(
plan_arch
[
"decoder_levels"
]),
**
kwargs
,
)
return
classifier
@
classmethod
def
_build_head_regressor
(
cls
,
plan_arch
:
dict
,
model_cfg
:
dict
,
anchor_generator
:
AnchorType
,
)
->
RegressorType
:
"""
Build regression subnetwork for detection head
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
anchor_generator: anchor generator instance
Returns:
RegressorType: classification instance
"""
conv
=
Generator
(
cls
.
head_conv_cls
,
plan_arch
[
"dim"
])
name
=
cls
.
head_regressor_cls
.
__name__
kwargs
=
model_cfg
[
'head_regressor_kwargs'
]
logger
.
info
(
f
"Building:: regressor
{
name
}
:
{
kwargs
}
"
)
regressor
=
cls
.
head_regressor_cls
(
conv
=
conv
,
in_channels
=
plan_arch
[
"fpn_channels"
],
internal_channels
=
plan_arch
[
"head_channels"
],
anchors_per_pos
=
anchor_generator
.
num_anchors_per_location
()[
0
],
num_levels
=
len
(
plan_arch
[
"decoder_levels"
]),
**
kwargs
,
)
return
regressor
@
classmethod
def
_build_head
(
cls
,
plan_arch
:
dict
,
model_cfg
:
dict
,
classifier
:
ClassifierType
,
regressor
:
RegressorType
,
coder
:
CoderType
,
)
->
HeadType
:
"""
Build detection head
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
classifier: classifier instance
regressor: regressor instance
coder: coder instance to encode boxes
Returns:
HeadType: instantiated head
"""
head_name
=
cls
.
head_cls
.
__name__
head_kwargs
=
model_cfg
[
'head_kwargs'
]
sampler_name
=
cls
.
head_sampler_cls
.
__name__
sampler_kwargs
=
model_cfg
[
'head_sampler_kwargs'
]
logger
.
info
(
f
"Building:: head
{
head_name
}
:
{
head_kwargs
}
"
f
"sampler
{
sampler_name
}
:
{
sampler_kwargs
}
"
)
sampler
=
cls
.
head_sampler_cls
(
**
sampler_kwargs
)
head
=
cls
.
head_cls
(
classifier
=
classifier
,
regressor
=
regressor
,
coder
=
coder
,
sampler
=
sampler
,
log_num_anchors
=
None
,
**
head_kwargs
,
)
return
head
@
classmethod
def
_build_segmenter
(
cls
,
plan_arch
:
dict
,
model_cfg
:
dict
,
decoder
:
DecoderType
,
)
->
SegmenterType
:
"""
Build segmenter head
Args:
plan_arch: architecture settings
model_cfg: additional architecture settings
decoder: decoder instance
Returns:
SegmenterType: segmenter head
"""
if
cls
.
segmenter_cls
is
not
None
:
name
=
cls
.
segmenter_cls
.
__name__
kwargs
=
model_cfg
[
'segmenter_kwargs'
]
conv
=
Generator
(
cls
.
base_conv_cls
,
plan_arch
[
"dim"
])
logger
.
info
(
f
"Building:: segmenter
{
name
}
{
kwargs
}
"
)
segmenter
=
cls
.
segmenter_cls
(
conv
,
seg_classes
=
plan_arch
[
"seg_classes"
],
in_channels
=
decoder
.
get_channels
(),
decoder_levels
=
plan_arch
[
"decoder_levels"
],
**
kwargs
,
)
else
:
segmenter
=
None
return
segmenter
@
staticmethod
def
get_ensembler_cls
(
key
:
Hashable
,
dim
:
int
)
->
Callable
:
"""
Get ensembler classes to combine multiple predictions
Needs to be overwritten in subclasses!
"""
_lookup
=
{
2
:
{
"boxes"
:
BoxEnsemblerSelective2D
,
"seg"
:
SegmentationEnsembler
,
},
3
:
{
"boxes"
:
BoxEnsemblerSelective
,
"seg"
:
SegmentationEnsembler
,
}
}
return
_lookup
[
dim
][
key
]
@
classmethod
def
get_predictor
(
cls
,
plan
:
Dict
,
models
:
Sequence
[
RetinaUNetModule
],
num_tta_transforms
:
int
=
None
,
do_seg
:
bool
=
False
,
**
kwargs
,
)
->
Predictor
:
# process plan
crop_size
=
plan
[
"patch_size"
]
batch_size
=
plan
[
"batch_size"
]
inferene_plan
=
plan
.
get
(
"inference_plan"
,
{})
logger
.
info
(
f
"Found inference plan:
{
inferene_plan
}
for prediction"
)
if
num_tta_transforms
is
None
:
num_tta_transforms
=
8
if
plan
[
"network_dim"
]
==
3
else
4
# setup
tta_transforms
,
tta_inverse_transforms
=
\
get_tta_transforms
(
num_tta_transforms
,
True
)
logger
.
info
(
f
"Using
{
len
(
tta_transforms
)
}
tta transformations for prediction (one dummy trafo)."
)
ensembler
=
{
"boxes"
:
partial
(
cls
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
parameters
=
inferene_plan
,
)}
if
do_seg
:
ensembler
[
"seg"
]
=
partial
(
cls
.
get_ensembler_cls
(
key
=
"seg"
,
dim
=
plan
[
"network_dim"
]).
from_case
,
)
predictor
=
Predictor
(
ensembler
=
ensembler
,
models
=
models
,
crop_size
=
crop_size
,
tta_transforms
=
tta_transforms
,
tta_inverse_transforms
=
tta_inverse_transforms
,
batch_size
=
batch_size
,
**
kwargs
,
)
if
plan
[
"network_dim"
]
==
2
:
predictor
.
pre_transform
=
Inference2D
([
"data"
])
return
predictor
def
sweep
(
self
,
cfg
:
dict
,
save_dir
:
os
.
PathLike
,
train_data_dir
:
os
.
PathLike
,
case_ids
:
Sequence
[
str
],
run_prediction
:
bool
=
True
,
**
kwargs
,
)
->
Dict
[
str
,
Any
]:
"""
Sweep detection parameters to find the best predictions
Args:
cfg: config used for training
save_dir: save dir used for training
train_data_dir: directory where preprocessed training/validation
data is located
case_ids: case identifies to prepare and predict
run_prediction: predict cases
**kwargs: keyword arguments passed to predict function
Returns:
Dict: inference plan
e.g. (exact params depend on ensembler class usef for prediction)
`iou_thresh` (float): best IoU threshold
`score_thresh (float)`: best score threshold
`no_overlap` (bool): enable/disable class independent NMS (ciNMS)
"""
logger
.
info
(
f
"Running parameter sweep on
{
case_ids
}
"
)
train_data_dir
=
Path
(
train_data_dir
)
preprocessed_dir
=
train_data_dir
.
parent
processed_eval_labels
=
preprocessed_dir
/
"labelsTr"
_save_dir
=
save_dir
/
"sweep"
_save_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
prediction_dir
=
save_dir
/
"sweep_predictions"
prediction_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
if
run_prediction
:
logger
.
info
(
"Predict cases with default settings..."
)
predictor
=
predict_dir
(
source_dir
=
train_data_dir
,
target_dir
=
prediction_dir
,
cfg
=
cfg
,
plan
=
self
.
plan
,
source_models
=
save_dir
,
num_models
=
1
,
num_tta_transforms
=
None
,
case_ids
=
case_ids
,
save_state
=
True
,
model_fn
=
load_final_model
,
**
kwargs
,
)
logger
.
info
(
"Start parameter sweep..."
)
ensembler_cls
=
self
.
get_ensembler_cls
(
key
=
"boxes"
,
dim
=
self
.
plan
[
"network_dim"
])
sweeper
=
BoxSweeper
(
classes
=
[
item
for
_
,
item
in
cfg
[
"data"
][
"labels"
].
items
()],
pred_dir
=
prediction_dir
,
gt_dir
=
processed_eval_labels
,
target_metric
=
self
.
eval_score_key
,
ensembler_cls
=
ensembler_cls
,
save_dir
=
_save_dir
,
)
inference_plan
=
sweeper
.
run_postprocessing_sweep
()
return
inference_plan
nndet/ptmodule/retinaunet/v001.py
0 → 100644
View file @
ede95851
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from
nndet.ptmodule.retinaunet.base
import
RetinaUNetModule
from
nndet.detection.boxes.matcher
import
ATSSMatcher
from
nndet.models.heads.classifier
import
BCECLassifier
from
nndet.models.heads.regressor
import
GIoURegressor
from
nndet.models.heads.comb
import
DetectionHeadHNMNative
from
nndet.models.heads.segmenter
import
DiCESegmenterFgBg
from
nndet.models.conv
import
ConvInstanceRelu
,
ConvGroupRelu
from
nndet.ptmodule
import
MODULE_REGISTRY
@
MODULE_REGISTRY
.
register
class
RetinaUNetV001
(
RetinaUNetModule
):
base_conv_cls
=
ConvInstanceRelu
head_conv_cls
=
ConvGroupRelu
head_cls
=
DetectionHeadHNMNative
head_classifier_cls
=
BCECLassifier
head_regressor_cls
=
GIoURegressor
matcher_cls
=
ATSSMatcher
segmenter_cls
=
DiCESegmenterFgBg
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment