Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
ede95851
Commit
ede95851
authored
Apr 22, 2021
by
mibaumgartner
Browse files
ptmodule
parent
4f533dd8
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1020 additions
and
0 deletions
+1020
-0
nndet/ptmodule/__init__.py
nndet/ptmodule/__init__.py
+6
-0
nndet/ptmodule/base_module.py
nndet/ptmodule/base_module.py
+201
-0
nndet/ptmodule/retinaunet/__init__.py
nndet/ptmodule/retinaunet/__init__.py
+3
-0
nndet/ptmodule/retinaunet/base.py
nndet/ptmodule/retinaunet/base.py
+771
-0
nndet/ptmodule/retinaunet/v001.py
nndet/ptmodule/retinaunet/v001.py
+39
-0
No files found.
nndet/ptmodule/__init__.py
0 → 100644
View file @
ede95851
from typing import Mapping, Type

from nndet.utils.registry import Registry
from nndet.ptmodule.base_module import LightningBaseModule

# Global registry mapping module identifiers to lightning-module classes.
# Concrete modules add themselves via the ``MODULE_REGISTRY.register``
# decorator (see e.g. ``RetinaUNetV001`` in retinaunet/v001.py).
MODULE_REGISTRY: Mapping[str, Type[LightningBaseModule]] = Registry()

# Imported for its side effect: importing the subpackage executes the
# ``register`` decorators and thereby populates MODULE_REGISTRY.
from nndet.ptmodule.retinaunet import *
nndet/ptmodule/base_module.py
0 → 100644
View file @
ede95851
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from
__future__
import
annotations
import
os
from
time
import
time
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Sequence
,
Hashable
,
Type
,
TypeVar
import
torch
import
pytorch_lightning
as
pl
from
pytorch_lightning.core.memory
import
ModelSummary
from
loguru
import
logger
from
nndet.io.load
import
save_txt
from
nndet.inference.predictor
import
Predictor
class LightningBaseModule(pl.LightningModule):
    """
    Base lightning module used inside of nnDetection.
    All lightning modules of nnDetection should be derived from this!
    """

    def __init__(
        self,
        model_cfg: dict,
        trainer_cfg: dict,
        plan: dict,
        **kwargs,
    ):
        """
        Provides a base module which is used inside of nnDetection.
        All lightning modules of nnDetection should be derived from this!

        Args:
            model_cfg: model configuration. Check :method:`from_config_plan`
                for more information
            trainer_cfg: trainer information
            plan: contains parameters which were derived from the planning
                stage
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.trainer_cfg = trainer_cfg
        self.plan = plan

        # Build the network from the configuration and the planned
        # architecture / anchor settings (implemented by subclasses).
        self.model = self.from_config_plan(
            model_cfg=self.model_cfg,
            plan_arch=self.plan["architecture"],
            plan_anchors=self.plan["anchors"],
        )

        # (batch=1, in_channels, *patch_size); consumed by
        # :attr:`example_input_array` when generating the model summary.
        self.example_input_array_shape = (
            1, plan["architecture"]["in_channels"], *plan["patch_size"],
        )

        # Wall-clock timestamps used to log the duration of an epoch.
        self.epoch_start_tic = 0
        self.epoch_end_toc = 0

    @property
    def max_epochs(self):
        """
        Number of epochs to train
        """
        return self.trainer_cfg["max_num_epochs"]

    def on_epoch_start(self) -> None:
        """
        Save time
        """
        self.epoch_start_tic = time()
        return super().on_epoch_start()

    def validation_epoch_end(self, validation_step_outputs):
        """
        Print time of epoch
        (needed for cluster where progress bar is deactivated)
        """
        self.epoch_end_toc = time()
        logger.info(f"This epoch took {int(self.epoch_end_toc - self.epoch_start_tic)} s")
        return super().validation_epoch_end(validation_step_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Used to generate summary
        Do not(!) use this for inference. This will only forward
        the input through the network which does not include
        detection specific postprocessing!
        """
        return self.model(x)

    @property
    def example_input_array(self):
        """
        Create example input
        """
        return torch.zeros(*self.example_input_array_shape)

    def summarize(self, mode: Optional[str]) -> Optional[ModelSummary]:
        """
        Save model summary as txt
        """
        summary = super().summarize(mode=mode)
        save_txt(summary, "./network")
        return summary

    def inference_step(self, batch: Any, **kwargs) -> Dict[str, Any]:
        """
        Prediction method used by nnDetection predictor class
        """
        return self.model.inference_step(batch, **kwargs)

    @classmethod
    def from_config_plan(
        cls,
        model_cfg: dict,
        plan_arch: dict,
        plan_anchors: dict,
        # was ``str = None``: PEP 484 forbids implicit Optional
        log_num_anchors: Optional[str] = None,
        **kwargs,
    ):
        """
        Used to generate the model
        Needs to be overwritten in subclasses!
        """
        raise NotImplementedError

    @staticmethod
    def get_ensembler_cls(key: Hashable, dim: int) -> Callable:
        """
        Get ensembler classes to combine multiple predictions
        Needs to be overwritten in subclasses!
        """
        raise NotImplementedError

    @classmethod
    def get_predictor(
        cls,
        plan: Dict,
        models: Sequence[LightningBaseModule],
        # was ``int = None``: PEP 484 forbids implicit Optional
        num_tta_transforms: Optional[int] = None,
        **kwargs,
    ) -> Type[Predictor]:
        """
        Get predictor
        Needs to be overwritten in subclasses!
        """
        raise NotImplementedError

    def sweep(
        self,
        cfg: dict,
        save_dir: os.PathLike,
        train_data_dir: os.PathLike,
        case_ids: Sequence[str],
        run_prediction: bool = True,
    ) -> Dict[str, Any]:
        """
        Sweep parameters to find the best predictions
        Needs to be overwritten in subclasses!

        Args:
            cfg: config used for training
            save_dir: save dir used for training
            train_data_dir: directory where preprocessed training/validation
                data is located
            case_ids: case identifies to prepare and predict
            run_prediction: predict cases
        """
        raise NotImplementedError
class LightningBaseModuleSWA(LightningBaseModule):
    """Base module variant that adds stochastic weight averaging epochs."""

    @property
    def max_epochs(self):
        """
        Number of epochs to train
        """
        cfg = self.trainer_cfg
        return cfg["max_num_epochs"] + cfg["swa_epochs"]

    def configure_callbacks(self):
        """
        Attach the SWA cyclic learning-rate callback; the SWA phase starts
        after the regular training epochs.
        """
        from nndet.training.swa import SWACycleLinear

        base_lr = self.trainer_cfg["initial_lr"]
        swa_callback = SWACycleLinear(
            swa_epoch_start=self.trainer_cfg["max_num_epochs"],
            cycle_initial_lr=base_lr / 10.,
            cycle_final_lr=base_lr / 1000.,
            num_iterations_per_epoch=self.trainer_cfg["num_train_batches_per_epoch"],
        )
        return [swa_callback]
# Type variable bound to LightningBaseModule; use to annotate helpers that
# accept and return the same concrete lightning-module subclass.
LightningBaseModuleType = TypeVar('LightningBaseModuleType', bound=LightningBaseModule)
nndet/ptmodule/retinaunet/__init__.py
0 → 100644
View file @
ede95851
from nndet.ptmodule.retinaunet.base import RetinaUNetModule
from nndet.ptmodule.retinaunet.v001 import RetinaUNetV001
# NOTE(review): c010 is not among the files added by this commit — confirm
# nndet/ptmodule/retinaunet/c010.py exists (e.g. from an earlier commit).
from nndet.ptmodule.retinaunet.c010 import RetinaUNetC010
nndet/ptmodule/retinaunet/base.py
0 → 100644
View file @
ede95851
This diff is collapsed.
Click to expand it.
nndet/ptmodule/retinaunet/v001.py
0 → 100644
View file @
ede95851
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from
nndet.ptmodule.retinaunet.base
import
RetinaUNetModule
from
nndet.detection.boxes.matcher
import
ATSSMatcher
from
nndet.models.heads.classifier
import
BCECLassifier
from
nndet.models.heads.regressor
import
GIoURegressor
from
nndet.models.heads.comb
import
DetectionHeadHNMNative
from
nndet.models.heads.segmenter
import
DiCESegmenterFgBg
from
nndet.models.conv
import
ConvInstanceRelu
,
ConvGroupRelu
from
nndet.ptmodule
import
MODULE_REGISTRY
@MODULE_REGISTRY.register
class RetinaUNetV001(RetinaUNetModule):
    """
    RetinaUNet variant V001: selects the concrete building blocks used by
    :class:`RetinaUNetModule` purely via class attributes.
    """

    # Convolution block types (base network vs. heads).
    base_conv_cls = ConvInstanceRelu
    head_conv_cls = ConvGroupRelu

    # Anchor matching strategy.
    matcher_cls = ATSSMatcher

    # Detection head and its classifier/regressor components.
    head_cls = DetectionHeadHNMNative
    head_classifier_cls = BCECLassifier
    head_regressor_cls = GIoURegressor

    # Auxiliary segmentation head.
    segmenter_cls = DiCESegmenterFgBg
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment