"csrc/git@developer.sourcefind.cn:jerrrrry/infinilm.git" did not exist on "13a4154ab94f7fb13c2c906c04c492d7bd38cd57"
Commit 4560ce77 authored by mibaumgartner's avatar mibaumgartner
Browse files

models

parent 94d6ac20
from nndet.core.boxes.anchors import get_anchor_generator, compute_anchors_for_strides, \ from nndet.core.boxes.anchors import get_anchor_generator, compute_anchors_for_strides, \
AnchorGenerator2D, AnchorGenerator2DS, AnchorGenerator3D, AnchorGenerator3DS AnchorGenerator2D, AnchorGenerator2DS, AnchorGenerator3D, AnchorGenerator3DS
from nndet.core.boxes.clip import clip_boxes_to_image_, clip_boxes_to_image from nndet.core.boxes.clip import clip_boxes_to_image_, clip_boxes_to_image
from nndet.core.boxes.coder import BoxCoderND from nndet.core.boxes.coder import CoderType, BoxCoderND
from nndet.core.boxes.matcher import MatcherType, Matcher, IoUMatcher, ATSSMatcher from nndet.core.boxes.matcher import MatcherType, Matcher, IoUMatcher, ATSSMatcher
from nndet.core.boxes.nms import nms, batched_nms from nndet.core.boxes.nms import nms, batched_nms
from nndet.core.boxes.sampler import AbstractSampler, NegativeSampler, HardNegativeSampler, \ from nndet.core.boxes.sampler import AbstractSampler, NegativeSampler, HardNegativeSampler, \
......
from __future__ import division from __future__ import division
import math import math
from typing import Sequence from typing import Sequence, TypeVar
import torch import torch
from torch.jit.annotations import List, Tuple from torch.jit.annotations import List, Tuple
...@@ -235,3 +235,6 @@ class BoxCoderND(BoxCoder): ...@@ -235,3 +235,6 @@ class BoxCoderND(BoxCoder):
def decode_single(self, rel_codes: torch.Tensor, boxes: torch.Tensor): def decode_single(self, rel_codes: torch.Tensor, boxes: torch.Tensor):
dtype, device = rel_codes.dtype, rel_codes.device dtype, device = rel_codes.dtype, rel_codes.device
return decode_single(rel_codes, boxes, self.weights, self.bbox_xform_clip) return decode_single(rel_codes, boxes, self.weights, self.bbox_xform_clip)
CoderType = TypeVar('CoderType', bound=BoxCoderND)
...@@ -6,6 +6,8 @@ from typing import List, Tuple, Dict, Any, Optional, Union ...@@ -6,6 +6,8 @@ from typing import List, Tuple, Dict, Any, Optional, Union
from nndet.models.abstract import AbstractModel from nndet.models.abstract import AbstractModel
from nndet.core import boxes as box_utils from nndet.core import boxes as box_utils
from nndet.models.encoder.abstract import EncoderType
from nndet.models.decoder.base import DecoderType
from nndet.models.heads.segmenter import SegmenterType from nndet.models.heads.segmenter import SegmenterType
from nndet.models.heads.comb import HeadType from nndet.models.heads.comb import HeadType
...@@ -18,7 +20,7 @@ class BaseRetinaNet(AbstractModel): ...@@ -18,7 +20,7 @@ class BaseRetinaNet(AbstractModel):
decoder: DecoderType, decoder: DecoderType,
head: HeadType, head: HeadType,
num_classes: int, num_classes: int,
anchor_generator: AnchorType, anchor_generator: box_utils.AnchorGenerator,
matcher: box_utils.MatcherType, matcher: box_utils.MatcherType,
decoder_levels: tuple = (2, 3, 4, 5), decoder_levels: tuple = (2, 3, 4, 5),
# post-processing # post-processing
......
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Dict, Tuple, Any, Optional
import torch
from abc import abstractmethod
from torch import Tensor
class AbstractModel(torch.nn.Module):
    """
    Abstract base class for detection models.

    Subclasses must provide a config-based alternate constructor
    (:meth:`from_config_plan`) and the two runtime entry points
    :meth:`train_step` and :meth:`inference_step`.
    """

    @classmethod
    @abstractmethod
    def from_config_plan(cls,
                         model_cfg: dict,
                         plan_arch: dict,
                         plan_anchors: dict,
                         log_num_anchors: str = None,
                         **kwargs,
                         ):
        """
        Build a model instance from configuration and planning results.

        Args:
            model_cfg: model specific hyperparameters
            plan_arch: architecture settings produced by the planning stage
            plan_anchors: anchor settings produced by the planning stage
            log_num_anchors: name of logger to use for anchor logging;
                if None, no logging is performed
            **kwargs: additional subclass specific keyword arguments
        """
        raise NotImplementedError

    @abstractmethod
    def train_step(self,
                   images: Tensor,
                   targets: dict,
                   evaluation: bool,
                   batch_num: int,
                   ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict]]:
        """
        Perform a single training step

        Args:
            images: images to process
            targets: labels for training
            evaluation (bool): compute final predictions which should be used
                for metric evaluation
            batch_num (int): batch index inside epoch

        Returns:
            Dict[str, torch.Tensor]: loss(es) for back propagation and
                scalars for logging (e.g. individual loss components)
            Optional[Dict]: predictions for metric calculation
                (presumably only populated when ``evaluation`` is True —
                TODO confirm against concrete implementations)
        """
        raise NotImplementedError

    @abstractmethod
    def inference_step(self,
                       images: Tensor,
                       *args,
                       **kwargs,
                       ) -> Dict[str, Any]:
        """
        Perform a single inference step

        Args:
            images: images to process
            *args: positional arguments
            **kwargs: keyword arguments

        Returns:
            Dict: predictions for metric calculation
        """
        raise NotImplementedError
from nndet.models.blocks.basic import AbstractBlock, StackedConvBlock, \
StackedResidualBlock, StackedConvBlock2
from nndet.models.blocks.res import ResBasic, ResBottleneck
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from abc import abstractmethod
from typing import Sequence, Callable, Union, Tuple
from nndet.models.conv import NdParam
from nndet.models.blocks.res import ResBasic
class AbstractBlock(nn.Module):
    """
    Common base for encoder building blocks.

    Stores the number of channels a block produces so that downstream
    code can query it via :meth:`get_output_channels`.
    """

    def __init__(self, out_channels: int, **kwargs):
        """
        Args:
            out_channels: number of channels produced by this block
            **kwargs: forwarded to :class:`torch.nn.Module`
        """
        super().__init__(**kwargs)
        self.out_channels = out_channels

    def get_output_channels(self) -> int:
        """
        Determine number of output channels of this block

        Returns:
            int: number of output channels
        """
        channels = self.out_channels
        return channels
class StackedBlock(AbstractBlock):
    """
    Base class for stacks of convolutional sub-blocks; subclasses define a
    single sub-block via :meth:`build_block`.
    """
    # factor by which channels grow when ``out_channels`` is not given
    expansion = 2

    def __init__(self,
                 conv: Callable[[], nn.Module],
                 in_channels: int,
                 conv_kernel: NdParam,
                 stride: NdParam = None,
                 out_channels: int = None,
                 max_out_channels: int = None,
                 num_blocks: int = 1,
                 **kwargs):
        """
        Plain stack of convolutions. Strides > 1 are applied at the beginning
        by a strided convolution and the first convolution raises the number of
        channels to :param:`out_channels`.

        Args:
            conv: conv generator to use for internal convolutions
            in_channels: number of input channels
            conv_kernel: kernel size of convolution
            stride: Stride of first convolution. If None stride=1 will be used.
                Defaults to None.
            out_channels: If given, then number of output channels will be set
                to this value. Otherwise the number of the input channels are
                doubled. Defaults to None.
            max_out_channels: Maximum number of output channels.
                Defaults to None.
            num_blocks: Number of blocks. Defaults to 1.

        Raises:
            ValueError: raised if given output channels are larger than max
                output channels
        """
        super().__init__(out_channels=None)  # real value assigned below
        if (out_channels is not None and
                max_out_channels is not None and
                out_channels > max_out_channels):
            # fixed: the two fragments previously concatenated to "largerthan"
            raise ValueError("Output channels can not be larger "
                             "than max output channels")
        if out_channels is None:
            out_channels = in_channels * self.expansion
        if max_out_channels is not None and out_channels > max_out_channels:
            out_channels = max_out_channels
        if stride is None:
            stride = 1
        if not isinstance(conv_kernel, Sequence):
            conv_kernel = [conv_kernel] * conv.dim
        # padding which preserves spatial size at stride 1
        padding = tuple([(i - 1) // 2 for i in conv_kernel])

        # first sub-block handles channel raise and (optional) downsampling
        _convs = [self.build_block(
            conv=conv, in_channels=in_channels, out_channels=out_channels,
            kernel_size=conv_kernel, stride=stride, padding=padding, **kwargs)]
        for _ in range(num_blocks - 1):
            _convs.append(self.build_block(
                conv=conv, in_channels=out_channels, out_channels=out_channels,
                kernel_size=conv_kernel, stride=1, padding=padding, **kwargs))
        self.convs = nn.Sequential(*_convs)
        self.out_channels = out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward tensor through the stacked sub-blocks

        Args:
            x: input tensor

        Returns:
            torch.Tensor: output tensor
        """
        return self.convs(x)

    @abstractmethod
    def build_block(self, conv: Callable[[], nn.Module],
                    in_channels: int, out_channels: int,
                    kernel_size: NdParam,
                    stride: NdParam,
                    padding: NdParam,
                    ) -> nn.Module:
        """
        Build a single sub-block of the stack (implemented by subclasses)

        Args:
            conv: generator for convolutions
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size: kernel size of convolutions
            stride: stride of first convolution
            padding: padding of convolutions

        Returns:
            nn.Module: single sub-block
        """
        raise NotImplementedError
class StackedConvBlock2(StackedBlock):
    def build_block(self, conv: Callable, in_channels: int,
                    out_channels: int, kernel_size: NdParam,
                    stride: NdParam, padding: NdParam,
                    **kwargs) -> nn.Module:
        """
        Build two consecutive convolutions; only the first one applies the
        (potentially strided) downsampling.

        Args:
            conv: generator for convolutions
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size: kernel size of convolutions
            stride: stride of first convolution
            padding: padding of convolutions

        Returns:
            nn.Module: stacked convolutions
        """
        layers = [
            conv(in_channels=in_channels, out_channels=out_channels,
                 kernel_size=kernel_size, stride=stride, padding=padding,
                 **kwargs),
            conv(in_channels=out_channels, out_channels=out_channels,
                 kernel_size=kernel_size, stride=1, padding=padding,
                 **kwargs),
        ]
        return torch.nn.Sequential(*layers)
class StackedConvBlock3(StackedBlock):
    def build_block(self, conv: Callable, in_channels: int,
                    out_channels: int, kernel_size: NdParam,
                    stride: NdParam, padding: NdParam,
                    **kwargs) -> nn.Module:
        """
        Build 3 consecutive convolutions (the docstring previously claimed 2);
        only the first convolution applies the (potentially strided)
        downsampling and the channel raise.

        Args:
            conv: generator for convolutions
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size: kernel size of convolutions
            stride: stride of first convolution
            padding: padding of convolutions

        Returns:
            nn.Module: stacked convolutions
        """
        return torch.nn.Sequential(
            conv(in_channels=in_channels, out_channels=out_channels,
                 kernel_size=kernel_size, stride=stride, padding=padding,
                 **kwargs),
            conv(in_channels=out_channels, out_channels=out_channels,
                 kernel_size=kernel_size, stride=1, padding=padding,
                 **kwargs),
            conv(in_channels=out_channels, out_channels=out_channels,
                 kernel_size=kernel_size, stride=1, padding=padding,
                 **kwargs),
        )
class StackedResidualBlock(StackedBlock):
    def build_block(self, conv: Callable[[], nn.Module], in_channels: int,
                    out_channels: int, kernel_size: NdParam,
                    stride: NdParam, padding: NdParam,
                    **kwargs) -> nn.Module:
        """
        Build a single basic residual block

        Args:
            conv: generator for convolutions
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size: kernel size of convolutions
            stride: stride of first convolution
            padding: padding of convolutions

        Returns:
            nn.Module: residual block
        """
        block = ResBasic(
            conv=conv,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            **kwargs,
        )
        return block
class StackedConvBlock(AbstractBlock):
    """
    Stack of plain convolutions built directly from the ``conv`` generator.
    """
    # factor by which channels grow when ``out_channels`` is not given
    expansion = 2

    def __init__(self,
                 conv: Callable[[], nn.Module],
                 in_channels: int,
                 conv_kernel: Union[Tuple[int], int],
                 stride: Union[Tuple[int], int] = None,
                 out_channels: int = None,
                 max_out_channels: int = None,
                 num_blocks: int = 2,
                 **kwargs):
        """
        Plain stack of convolutions. Strides > 1 are applied at the beginning
        by a strided convolution and the first convolution raises the number of
        channels to :param:`out_channels`.

        Args:
            conv: conv generator to use for internal convolutions
            in_channels: number of input channels
            conv_kernel: kernel size of convolution
            stride: Stride of first convolution. If None stride=1 will be used.
                Defaults to None.
            out_channels: If given, then number of output channels will be set
                to this value. Otherwise the number of the input channels are
                doubled. Defaults to None.
            max_out_channels: Maximum number of output channels.
                Defaults to None.
            num_blocks: Number of convolutions. Defaults to 2.

        Raises:
            ValueError: raised if given output channels are larger than max
                output channels
        """
        super().__init__(out_channels=None)  # real value assigned below
        if (out_channels is not None and
                max_out_channels is not None and
                out_channels > max_out_channels):
            # fixed: the two fragments previously concatenated to "largerthan"
            raise ValueError("Output channels can not be larger "
                             "than max output channels")
        if out_channels is None:
            out_channels = in_channels * self.expansion
        if max_out_channels is not None and out_channels > max_out_channels:
            out_channels = max_out_channels
        if stride is None:
            stride = 1
        if not isinstance(conv_kernel, Sequence):
            conv_kernel = [conv_kernel] * conv.dim
        # padding which preserves spatial size at stride 1
        padding = tuple([(i - 1) // 2 for i in conv_kernel])

        # first conv handles channel raise and (optional) downsampling
        _convs = [conv(in_channels=in_channels,
                       out_channels=out_channels,
                       kernel_size=conv_kernel,
                       stride=stride,
                       padding=padding,
                       **kwargs)]
        for _ in range(num_blocks - 1):
            _convs.append(conv(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=conv_kernel,
                               stride=1,
                               padding=padding,
                               **kwargs))
        self.convs = nn.Sequential(*_convs)
        self.out_channels = out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward tensor through the stacked convolutions

        Args:
            x: input tensor

        Returns:
            torch.Tensor: output tensor
        """
        return self.convs(x)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from typing import Sequence, Callable, Optional
from functools import reduce
from loguru import logger
from nndet.models.conv import nd_pool
from nndet.models.conv import NdParam
class ResBasic(nn.Module):
    def __init__(self,
                 conv: Callable,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: NdParam,
                 stride: NdParam,
                 padding: NdParam,
                 attention: Optional[nn.Module] = None,
                 ):
        """
        Build a plain residual block

        Zero init norm according to https://arxiv.org/abs/1706.02677
        Avg pool in downsampling path https://arxiv.org/pdf/1812.01187.pdf

        Args:
            conv: generator for convolutions
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size: kernel size of convolutions
            stride: stride of first convolution
            padding: padding of convolutions
            attention: additional attention layer applied after convolutions
        """
        super().__init__()
        logger.warning("ResidualBlock uses normal relu! This might not be "
                       "desired if conv uses a different non linearity")
        self.conv1 = conv(in_channels, out_channels, kernel_size=kernel_size,
                          padding=padding, stride=stride)
        # last conv without non linearity; relu is applied after the addition
        self.conv2 = conv(out_channels, out_channels, kernel_size=kernel_size,
                          padding=padding, relu=None)
        self.relu = nn.ReLU(inplace=True)

        # total downsampling factor of this block
        stride_prod = (reduce((lambda x, y: x * y), stride)
                       if isinstance(stride, Sequence) else stride)
        if stride_prod > 1:
            # identity needs matching resolution/channels: avg pool + 1x1 conv
            self.shortcut = nn.Sequential(
                nd_pool("Avg", dim=conv.dim, kernel_size=stride, stride=stride),
                conv(in_channels, out_channels, kernel_size=1, relu=None),
            )
        else:
            self.shortcut = None
        self.attention = attention
        self.init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward input

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output tensor
        """
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.attention:
            out = self.attention(out)
        if self.shortcut:
            residual = self.shortcut(x)
        out += residual
        out = self.relu(out)
        return out

    def init_weights(self) -> None:
        """
        Zero-init the weight of the last normalization layer so the block
        initially acts close to identity (https://arxiv.org/abs/1706.02677)
        """
        try:
            torch.nn.init.zeros_(self.conv2.norm.weight)
        # fixed: bare `except:` replaced; additionally the old log message
        # evaluated `self.conv2.norm`, which itself raised AttributeError
        # when `norm` was missing (the very case being handled)
        except Exception:
            logger.info(f"Zero init of last norm layer "
                        f"{getattr(self.conv2, 'norm', None)} failed")
class ResBottleneck(nn.Module):
    def __init__(self,
                 conv: Callable,
                 in_channels: int,
                 internal_channels: int,
                 kernel_size: NdParam,
                 stride: NdParam,
                 padding: NdParam,
                 expansion: int = 1,
                 attention: Optional[nn.Module] = None,
                 ):
        """
        Build a bottleneck residual block

        Zero init norm according to https://arxiv.org/abs/1706.02677
        Avg pool in downsampling path https://arxiv.org/pdf/1812.01187.pdf

        in_channels -> internal_channels -> internal_channels * expansion

        Args:
            conv: generator for convolutions
            in_channels: number of input channels
            internal_channels: number of internal channels to use.
                The number of output channels will be
                internal_channels * expansion
            kernel_size: kernel size of convolutions
            stride: stride of first convolution
            padding: padding of convolutions
            expansion: expansion for last conv block. Default expansion
                is one to be compatible with modular encoder! Original
                implementation uses expansion=4.
            attention: additional attention layer applied after convolutions
        """
        super().__init__()
        logger.warning("ResidualBlock uses normal relu! This might not be "
                       "desired if conv uses a different non linearity")
        out_channels = internal_channels * expansion
        # 1x1 reduce -> kxk spatial (strided) -> 1x1 expand; the final conv
        # has no non linearity, relu is applied after the addition
        self.conv1 = conv(in_channels, internal_channels,
                          kernel_size=1, padding=0, stride=1,
                          )
        self.conv2 = conv(internal_channels, internal_channels,
                          kernel_size=kernel_size, padding=padding, stride=stride,
                          )
        self.conv3 = conv(internal_channels, out_channels,
                          kernel_size=1, padding=0, relu=None, stride=1,
                          )
        self.relu = nn.ReLU(inplace=True)

        # downsampling path
        stride_prod = (reduce((lambda x, y: x * y), stride)
                       if isinstance(stride, Sequence) else stride)
        if stride_prod > 1:
            self.shortcut = nn.Sequential(
                nd_pool("Avg", dim=conv.dim, kernel_size=stride, stride=stride),
                conv(in_channels, out_channels, kernel_size=1, relu=None),
            )
        else:
            self.shortcut = None
        self.attention = attention
        self.init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward input

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output tensor
        """
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        if self.attention:
            out = self.attention(out)
        if self.shortcut:
            residual = self.shortcut(x)
        out += residual
        out = self.relu(out)
        return out

    def init_weights(self) -> None:
        """
        Zero-init the weight of the last normalization layer so the block
        initially acts close to identity (https://arxiv.org/abs/1706.02677)
        """
        try:
            # fixed: in the bottleneck the LAST norm layer sits in conv3;
            # targeting conv2 looked like a copy-paste from the basic block
            # (the log message already said "last norm layer")
            torch.nn.init.zeros_(self.conv3.norm.weight)
        # fixed: bare `except:` replaced; the old log message evaluated the
        # `.norm` attribute and could itself raise AttributeError
        except Exception:
            logger.info(f"Zero init of last norm layer "
                        f"{getattr(self.conv3, 'norm', None)} failed")
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from nndet.models.conv import nd_pool, nd_conv
class SELayer(nn.Module):
    def __init__(self,
                 dim: int,
                 in_channels: int,
                 reduction: int = 16,
                 ):
        """
        Squeeze and Excitation Layer
        https://arxiv.org/abs/1709.01507

        Args:
            dim: number of spatial dimensions
            in_channels: number of input channels
            reduction: channel reduction for internal computations
        """
        super().__init__()
        hidden_channels = in_channels // reduction
        # squeeze: global average pool to a single value per channel
        self.pool = nd_pool("AdaptiveAvg", dim, 1)
        # excite: bottleneck MLP realized with 1x1 convolutions
        self.fc = nn.Sequential(
            nd_conv(dim, in_channels, hidden_channels,
                    kernel_size=1, stride=1, bias=False),
            nn.ReLU(inplace=True),
            nd_conv(dim, hidden_channels, in_channels,
                    kernel_size=1, stride=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale each channel of ``x`` by its learned attention weight."""
        scale = self.fc(self.pool(x))
        return x * scale
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from typing import Union, Callable, Any, Optional, Tuple, Sequence, Type
from nndet.models.initializer import InitWeights_He
from nndet.models.layers.norm import GroupNorm
NdParam = Union[int, Tuple[int, int], Tuple[int, int, int]]
class Generator:
    """
    Factory helper which binds a convolution class and the number of
    spatial dimensions so that objects can be created later with only
    the remaining arguments.
    """

    def __init__(self, conv_cls, dim: int):
        """
        Args:
            conv_cls (callable): class of convolution
            dim (int): number of spatial dimensions (in general 2 or 3)
        """
        self.dim = dim
        self.conv_cls = conv_cls

    def __call__(self, *args, **kwargs) -> Any:
        """
        Instantiate the saved class, passing ``dim`` as first argument.

        Args:
            *args: passed to object
            **kwargs: passed to object

        Returns:
            Any: created object
        """
        return self.conv_cls(self.dim, *args, **kwargs)
class BaseConvNormAct(torch.nn.Sequential):
    def __init__(self,
                 dim: int,
                 in_channels: int,
                 out_channels: int,
                 norm: Optional[Union[Callable[..., Type[nn.Module]], str]],
                 act: Optional[Union[Callable[..., Type[nn.Module]], str]],
                 kernel_size: Union[int, tuple],
                 stride: Union[int, tuple] = 1,
                 padding: Union[int, tuple] = 0,
                 dilation: Union[int, tuple] = 1,
                 groups: int = 1,
                 bias: bool = None,
                 transposed: bool = False,
                 norm_kwargs: Optional[dict] = None,
                 act_inplace: Optional[bool] = None,
                 act_kwargs: Optional[dict] = None,
                 initializer: Callable[[nn.Module], None] = None,
                 ):
        """
        Base class assembling the default ordering:
        conv -> norm -> activation

        Args:
            dim: number of dimensions the convolution should be chosen for
            in_channels: input channels
            out_channels: output channels
            norm: type of normalization; if None, no normalization is applied
            act: class or name of non linearity; if None, no activation is used
            kernel_size: size of convolution kernel
            stride: convolution stride
            padding: padding value (input or output padding depending on
                whether the convolution is transposed or not)
            dilation: convolution dilation
            groups: number of convolution groups
            bias: whether to include bias; if None, determined dynamically:
                False if a normalization follows, otherwise True
            transposed: whether the convolution should be transposed or not
            norm_kwargs: keyword arguments for normalization layer
            act_inplace: whether to perform the activation inplace; if None,
                determined dynamically: True if a normalization follows,
                otherwise False
            act_kwargs: keyword arguments for non linearity layer
            initializer: initialize weights
        """
        super().__init__()
        if norm_kwargs is None:
            norm_kwargs = {}
        if act_kwargs is None:
            act_kwargs = {}
        if "inplace" in act_kwargs:
            raise ValueError("Use keyword argument to en-/disable inplace activations")
        # activations default to inplace when a norm layer precedes them
        if act_inplace is None:
            act_inplace = bool(norm is not None)
        act_kwargs["inplace"] = act_inplace

        # bias is redundant directly before a normalization layer
        if bias is None:
            bias = bool(norm is None)

        self.add_module("conv", nd_conv(
            dim=dim,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            transposed=transposed,
        ))

        if norm is not None:
            norm_layer = (nd_norm(norm, dim, out_channels, **norm_kwargs)
                          if isinstance(norm, str)
                          else norm(dim, out_channels, **norm_kwargs))
            self.add_module("norm", norm_layer)

        if act is not None:
            act_layer = (nd_act(act, dim, **act_kwargs)
                         if isinstance(act, str)
                         else act(**act_kwargs))
            self.add_module("act", act_layer)

        if initializer is not None:
            self.apply(initializer)
class ConvInstanceRelu(BaseConvNormAct):
    def __init__(self,
                 dim: int,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, tuple],
                 stride: Union[int, tuple] = 1,
                 padding: Union[int, tuple] = 0,
                 dilation: Union[int, tuple] = 1,
                 groups: int = 1,
                 bias: bool = None,
                 transposed: bool = False,
                 add_norm: bool = True,
                 add_act: bool = True,
                 act_inplace: Optional[bool] = None,
                 norm_eps: float = 1e-5,
                 norm_affine: bool = True,
                 initializer: Callable[[nn.Module], None] = None,
                 ):
        """
        Convolution followed by (optional) instance norm and ReLU:
        conv -> norm -> activation

        Args:
            dim: number of dimensions the convolution should be chosen for
            in_channels: input channels
            out_channels: output channels
            kernel_size: size of convolution kernel
            stride: convolution stride
            padding: padding value (input or output padding depending on
                whether the convolution is transposed or not)
            dilation: convolution dilation
            groups: number of convolution groups
            bias: whether to include bias; if None, determined dynamically:
                False if a normalization follows, otherwise True
            transposed: whether the convolution should be transposed or not
            add_norm: add normalization layer to conv block
            add_act: add activation layer to conv block
            act_inplace: whether to perform the activation inplace; if None,
                determined dynamically: True if a normalization follows,
                otherwise False
            norm_eps: instance norm eps (see pytorch for more info)
            norm_affine: instance norm affine parameter (see pytorch for more info)
            initializer: initialize weights
        """
        super().__init__(
            dim=dim,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            transposed=transposed,
            norm="Instance" if add_norm else None,
            act="ReLU" if add_act else None,
            norm_kwargs={"eps": norm_eps, "affine": norm_affine},
            act_inplace=act_inplace,
            initializer=initializer,
        )
class ConvGroupRelu(BaseConvNormAct):
    def __init__(self,
                 dim: int,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, tuple],
                 stride: Union[int, tuple] = 1,
                 padding: Union[int, tuple] = 0,
                 dilation: Union[int, tuple] = 1,
                 groups: int = 1,
                 bias: bool = None,
                 transposed: bool = False,
                 add_norm: bool = True,
                 add_act: bool = True,
                 act_inplace: Optional[bool] = None,
                 norm_eps: float = 1e-5,
                 norm_affine: bool = True,
                 norm_channels_per_group: int = 16,
                 initializer: Callable[[nn.Module], None] = None,
                 ):
        """
        Convolution followed by (optional) group norm and ReLU:
        conv -> norm -> activation

        Args:
            dim: number of dimensions the convolution should be chosen for
            in_channels: input channels
            out_channels: output channels
            kernel_size: size of convolution kernel
            stride: convolution stride
            padding: padding value (input or output padding depending on
                whether the convolution is transposed or not)
            dilation: convolution dilation
            groups: number of convolution groups
            bias: whether to include bias; if None, determined dynamically:
                False if a normalization follows, otherwise True
            transposed: whether the convolution should be transposed or not
            add_norm: add normalization layer to conv block
            add_act: add activation layer to conv block
            act_inplace: whether to perform the activation inplace; if None,
                determined dynamically: True if a normalization follows,
                otherwise False
            norm_eps: group norm eps (see pytorch for more info)
            norm_affine: group norm affine parameter (see pytorch for more info)
            norm_channels_per_group: channels per group for group norm
            initializer: initialize weights
        """
        super().__init__(
            dim=dim,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            transposed=transposed,
            norm="Group" if add_norm else None,
            act="ReLU" if add_act else None,
            norm_kwargs={
                "eps": norm_eps,
                "affine": norm_affine,
                "channels_per_group": norm_channels_per_group,
            },
            act_inplace=act_inplace,
            initializer=initializer,
        )
def nd_conv(dim: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, tuple],
            stride: Union[int, tuple] = 1,
            padding: Union[int, tuple] = 0,
            dilation: Union[int, tuple] = 1,
            groups: int = 1,
            bias: bool = True,
            transposed: bool = False,
            **kwargs,
            ) -> torch.nn.Module:
    """
    Convolution wrapper to switch across dimensions and transposed
    convolutions with a single argument

    Args:
        dim: number of dimensions the convolution should be chosen for
        in_channels: input channels
        out_channels: output channels
        kernel_size: size of convolution kernel
        stride: convolution stride
        padding: padding value (input or output padding depending on
            whether the convolution is transposed or not)
        dilation: convolution dilation
        groups: number of convolution groups
        bias: whether to include bias or not
        transposed: whether the convolution should be transposed or not

    Returns:
        torch.nn.Module: generated module

    See Also:
        :class:`torch.nn.Conv1d`, :class:`torch.nn.Conv2d`,
        :class:`torch.nn.Conv3d`, :class:`torch.nn.ConvTranspose1d`,
        :class:`torch.nn.ConvTranspose2d`, :class:`torch.nn.ConvTranspose3d`
    """
    # resolve e.g. "Conv2d" or "ConvTranspose3d" from torch.nn by name
    cls_name = f"ConvTranspose{dim}d" if transposed else f"Conv{dim}d"
    conv_cls = getattr(torch.nn, cls_name)
    return conv_cls(in_channels=in_channels, out_channels=out_channels,
                    kernel_size=kernel_size, stride=stride, padding=padding,
                    dilation=dilation, groups=groups, bias=bias, **kwargs)
def nd_pool(pooling_type: str, dim: int, *args, **kwargs) -> torch.nn.Module:
    """
    Wrapper to switch between different pooling types and dimensions with a
    single argument

    Args:
        pooling_type: type of pooling, case sensitive.
            Supported values are
            * ``Max``
            * ``Avg``
            * ``AdaptiveAvg``
            * ``AdaptiveMax``
        dim: number of dimensions
        *args: positional arguments of the chosen pooling class
        **kwargs: keyword arguments of the chosen pooling class

    Returns:
        torch.nn.Module: generated module

    See Also:
        :class:`torch.nn.MaxPool1d` ... :class:`torch.nn.MaxPool3d`,
        :class:`torch.nn.AvgPool1d` ... :class:`torch.nn.AvgPool3d`,
        :class:`torch.nn.AdaptiveMaxPool1d` ... :class:`torch.nn.AdaptiveMaxPool3d`,
        :class:`torch.nn.AdaptiveAvgPool1d` ... :class:`torch.nn.AdaptiveAvgPool3d`
    """
    # e.g. ("Max", 2) -> torch.nn.MaxPool2d
    cls_name = f"{pooling_type}Pool{dim}d"
    return getattr(torch.nn, cls_name)(*args, **kwargs)
def nd_norm(norm_type: str, dim: int, *args, **kwargs) -> torch.nn.Module:
    """
    Wrapper to switch between different types of normalization and
    dimensions by a single argument

    Args:
        norm_type (str): type of normalization, case sensitive.
            Supported types are:
            * ``Batch``
            * ``Instance``
            * ``LocalResponse``
            * ``Group``
            * ``Layer``
        dim (int, None): dimension of normalization input; can be None if
            normalization is dimension-agnostic (e.g. LayerNorm,
            LocalResponseNorm)
        *args: positional arguments of chosen normalization class
        **kwargs: keyword arguments of chosen normalization class

    Returns:
        torch.nn.Module: generated module

    See Also:
        :class:`torch.nn.BatchNorm1d` ... :class:`torch.nn.BatchNorm3d`,
        :class:`torch.nn.InstanceNorm1d` ... :class:`torch.nn.InstanceNorm3d`,
        :class:`torch.nn.LocalResponseNorm`,
        :class:`nndet.models.layers.norm.GroupNorm`
    """
    if norm_type.lower() == "group":
        norm_cls = GroupNorm
    elif dim is None:
        # fixed: dimension-agnostic norms (LayerNorm, LocalResponseNorm) have
        # no "Xd" suffix in torch.nn; the previous f-string still appended a
        # trailing "d" (e.g. "LocalResponseNormd") and always raised
        norm_cls = getattr(torch.nn, f"{norm_type}Norm")
    else:
        norm_cls = getattr(torch.nn, f"{norm_type}Norm{dim}d")
    return norm_cls(*args, **kwargs)
def nd_act(act_type: str, dim: int, *args, **kwargs) -> torch.nn.Module:
    """
    Look up an activation layer from :mod:`torch.nn` by its class name.

    Args:
        act_type: name of the activation class to instantiate
            (e.g. ``"ReLU"``, ``"LeakyReLU"``)
        dim: unused; kept for interface compatibility with the other
            ``nd_*`` generator helpers
        *args: positional arguments forwarded to the activation class
        **kwargs: keyword arguments forwarded to the activation class

    Returns:
        torch.nn.Module: instantiated activation module
    """
    return getattr(torch.nn, act_type)(*args, **kwargs)
def nd_dropout(dim: int, p: float = 0.5, inplace: bool = False, **kwargs) -> torch.nn.Module:
    """
    Create a 1d, 2d or 3d dropout layer.

    Args:
        dim: number of spatial dimensions (1, 2 or 3)
        p: dropout probability
        inplace: whether to apply the operation in place
        **kwargs: additional keyword arguments for the dropout class

    Returns:
        torch.nn.Module: instantiated dropout module
    """
    dropout = getattr(torch.nn, f"Dropout{dim}d")
    return dropout(p=p, inplace=inplace, **kwargs)
def compute_padding_for_kernel(kernel_size: Union[int, Sequence[int]]) -> \
        Union[int, Tuple[int, int], Tuple[int, int, int]]:
    """
    Compute the padding that keeps feature-map size unchanged at stride 1.

    Args:
        kernel_size: scalar kernel size or one size per spatial dimension

    Returns:
        Union[int, Tuple[int, int], Tuple[int, int, int]]: computed padding;
        scalar if the input was scalar, tuple otherwise
    """
    if not isinstance(kernel_size, Sequence):
        return (kernel_size - 1) // 2
    return tuple((k - 1) // 2 for k in kernel_size)
def conv_kwargs_helper(norm: bool, activation: bool):
    """
    Build the keyword arguments that toggle normalization and activation
    in conv generators which include both by default.

    Args:
        norm: whether to keep the normalization layer
        activation: whether to keep the activation layer

    Returns:
        dict: keyword arguments to pass to a conv generator
    """
    return {"add_norm": norm, "add_act": activation}
from nndet.models.decoder.fpn import FPN, UFPN, FPN2
# NOTE: "This diff is collapsed." is an artifact of the diff view this file was copied from; a file's contents are missing here.
from nndet.models.encoder.abstract import AbstractEncoder
from nndet.models.encoder.modular import Encoder
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from typing import List, Dict, Union, TypeVar
from abc import abstractmethod
__all__ = ["AbstractEncoder"]
class AbstractEncoder(nn.Module):
    """
    Abstract interface for backbone (encoder) networks.

    Subclasses implement :meth:`forward`, :meth:`get_channels` and
    :meth:`get_strides`.
    """

    def __init__(self, **kwargs):
        """
        Provides an abstract interface for backbone networks

        Args:
            **kwargs: forwarded to :class:`torch.nn.Module`
        """
        # BUGFIX: was spelled `__int__` (typo), so this constructor was
        # never invoked; fixed to the intended `__init__`.
        super().__init__(**kwargs)

    @abstractmethod
    def forward(self, x) -> List[torch.Tensor]:
        """
        Forward input through the network

        Args:
            x (torch.Tensor): input tensor

        Returns:
            List[torch.Tensor]: feature maps from multiple resolutions
        """
        raise NotImplementedError

    @abstractmethod
    def get_channels(self) -> List[int]:
        """
        Number of channels of each feature map returned by :meth:`forward`

        Returns:
            List[int]: number of channels per returned feature map
        """
        raise NotImplementedError

    @abstractmethod
    def get_strides(self) -> List[Dict[str, Union[List[int], int]]]:
        """
        Backbone strides (2d and 3d cases) with respect to the input

        Returns:
            List[Dict[str, Union[List[int], int]]]: dict with 'xy' for the
                2d stride and optional 'z' for 3d cases; one entry per
                output level
        """
        raise NotImplementedError
EncoderType = TypeVar('EncoderType', bound=AbstractEncoder)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from typing import Callable, Tuple, Sequence, Union, List, Dict, Optional
from nndet.models.encoder.abstract import AbstractEncoder
from nndet.models.blocks.basic import AbstractBlock
__all__ = ["Encoder"]
class Encoder(AbstractEncoder):
    """
    Modular encoder built from a sequence of convolutional "stages".

    Each stage except the first downsamples via the stride of its block;
    the first stage always runs at full resolution.
    """

    def __init__(self,
                 conv: Callable[[], nn.Module],
                 conv_kernels: Sequence[Union[Tuple[int], int]],
                 strides: Sequence[Union[Tuple[int], int]],
                 block_cls: AbstractBlock,
                 in_channels: int,
                 start_channels: int,
                 stage_kwargs: Sequence[dict] = None,
                 out_stages: Sequence[int] = None,
                 max_channels: int = None,
                 first_block_cls: Optional[AbstractBlock] = None,
                 ):
        """
        Build a modular encoder model with specified blocks

        The Encoder consists of "stages" which (in general) represent one
        resolution in the resolution pyramid. The first level always has
        full resolution.

        Args:
            conv: conv generator to use for internal convolutions
            conv_kernels: kernel sizes for convolutions
            strides: strides for pooling layers. Should have one
                element less than conv_kernels
            block_cls: generates a block of convolutions
                (e.g. stacked residual blocks)
            in_channels: number of input channels
            start_channels: number of channels of the first stage
            stage_kwargs: additional keyword arguments for the stages.
                Defaults to None.
            out_stages: stages whose outputs are returned by forward.
                If None, all stages are returned. Defaults to None.
            max_channels: upper bound for the output channels of a block
            first_block_cls: generates the block of convolutions for the
                first stage. Defaults to ``block_cls``.
        """
        super().__init__()
        self.num_stages = len(conv_kernels)
        # spatial dimensionality is provided by the conv generator
        self.dim = conv.dim
        # broadcast stage kwargs: None -> empty dict per stage,
        # a single dict -> the same dict for every stage
        if stage_kwargs is None:
            stage_kwargs = [{}] * self.num_stages
        elif isinstance(stage_kwargs, dict):
            stage_kwargs = [stage_kwargs] * self.num_stages
        assert len(stage_kwargs) == len(conv_kernels)
        if out_stages is None:
            self.out_stages = list(range(self.num_stages))
        else:
            self.out_stages = out_stages
        if first_block_cls is None:
            first_block_cls = block_cls
        stages = []
        self.out_channels = []
        # normalize scalar strides to one value per spatial dimension
        if isinstance(strides[0], int):
            strides = [tuple([s] * self.dim) for s in strides]
        self.strides = strides
        for stage_id in range(self.num_stages):
            if stage_id == 0:
                # first stage: no downsampling (stride=None), fixed
                # number of output channels (start_channels)
                _block = first_block_cls(
                    conv=conv,
                    in_channels=in_channels,
                    out_channels=start_channels,
                    conv_kernel=conv_kernels[stage_id],
                    stride=None,
                    max_out_channels=max_channels,
                    **stage_kwargs[stage_id],
                )
            else:
                # later stages: block picks its own output channels
                # (out_channels=None) and downsamples with the stride
                # preceding this stage
                _block = block_cls(
                    conv=conv,
                    in_channels=in_channels,
                    out_channels=None,
                    conv_kernel=conv_kernels[stage_id],
                    stride=strides[stage_id - 1],
                    max_out_channels=max_channels,
                    **stage_kwargs[stage_id],
                )
            # chain: output channels of this block feed the next stage
            in_channels = _block.get_output_channels()
            self.out_channels.append(in_channels)
            stages.append(_block)
        self.stages = torch.nn.ModuleList(stages)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Forward data through the encoder

        Args:
            x: input data

        Returns:
            List[torch.Tensor]: outputs of the stages selected via
                :attr:`out_stages`
        """
        outputs = []
        for stage_id, module in enumerate(self.stages):
            x = module(x)
            if stage_id in self.out_stages:
                outputs.append(x)
        return outputs

    def get_channels(self) -> List[int]:
        """
        Number of channels of each feature map returned by :meth:`forward`

        Returns:
            List[int]: number of channels per returned feature map
        """
        out_channels = []
        for stage_id in range(self.num_stages):
            if stage_id in self.out_stages:
                out_channels.append(self.out_channels[stage_id])
        return out_channels

    def get_strides(self) -> List[List[int]]:
        """
        Absolute stride of every stage with respect to the input size

        NOTE(review): returns plain per-axis stride lists, not the
        ``{'xy': ..., 'z': ...}`` dict form declared on
        :class:`AbstractEncoder` — confirm callers expect this shape.

        Returns:
            List[List[int]]: absolute stride per stage
        """
        out_strides = []
        for stage_id in range(self.num_stages):
            if stage_id == 0:
                # first stage runs at full resolution
                out_strides.append([1] * self.dim)
            else:
                # accumulate: previous absolute stride times this
                # stage's pooling stride, per axis
                new_stride = [prev_stride * pool_size for prev_stride, pool_size
                              in zip(out_strides[stage_id - 1], self.strides[stage_id - 1])]
                out_strides.append(new_stride)
        return out_strides
from nndet.models.heads.classifier import ClassifierType, Classifier
from nndet.models.heads.comb import HeadType, AbstractHead
from nndet.models.heads.regressor import RegressorType, Regressor
from nndet.models.heads.segmenter import SegmenterType, Segmenter
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import math
import torch.nn as nn
from typing import Optional, TypeVar
from torch import Tensor
from abc import abstractmethod
from loguru import logger
from nndet.losses.classification import (
AsymmetricFocalLossWithLogits,
FocalLossWithLogits,
BCEWithLogitsLossOneHot,
CrossEntropyLoss,
)
CONV_TYPES = (nn.Conv2d, nn.Conv3d)
class Classifier(nn.Module):
    """
    Abstract interface for classification heads.

    Subclasses implement :meth:`compute_loss` and
    :meth:`box_logits_to_probs`.
    """

    @abstractmethod
    def compute_loss(self, pred_logits: Tensor, targets: Tensor, **kwargs) -> Tensor:
        """
        Compute the classification loss

        Args:
            pred_logits (Tensor): predicted logits
            targets (Tensor): classification targets

        Returns:
            Tensor: classification loss
        """
        raise NotImplementedError

    @abstractmethod
    def box_logits_to_probs(self, box_logits: Tensor) -> Tensor:
        """
        Convert bounding box logits to probabilities

        Args:
            box_logits (Tensor): bounding box logits [N, C],
                C = number of classes

        Returns:
            Tensor: probabilities
        """
        raise NotImplementedError
class BaseClassifier(Classifier):
    """
    Base class for classifier heads with the conv structure
    conv(in, internal) -> num_convs x conv(internal, internal)
    -> conv(internal, out).

    NOTE(review): :meth:`init_weights` reads ``self.prior_prob``, which this
    base class never assigns — subclasses must set it *before* calling
    ``super().__init__()`` (see e.g. ``BCECLassifier``).
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 num_classes: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 **kwargs
                 ):
        """
        Base class to build classifier heads with typical conv structure

        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            num_classes: number of foreground classes
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the
                classifier
            num_convs: number of internal convolutions
                (input_conv -> num_convs -> output_conv)
            add_norm: en-/disable normalization layers in internal layers
            kwargs: keyword arguments passed to first and internal
                convolutions

        Notes:
            ``self.loss`` needs to be overwritten in subclasses
            ``self.logits_convert_fn`` needs to be overwritten in subclasses
        """
        super().__init__()
        self.dim = conv.dim
        self.num_levels = num_levels
        self.num_convs = num_convs
        self.num_classes = num_classes
        self.anchors_per_pos = anchors_per_pos
        self.in_channels = in_channels
        self.internal_channels = internal_channels
        self.conv_internal = self.build_conv_internal(conv, add_norm=add_norm, **kwargs)
        self.conv_out = self.build_conv_out(conv)
        # both must be assigned by subclasses before use
        self.loss: Optional[nn.Module] = None
        self.logits_convert_fn: Optional[nn.Module] = None
        self.init_weights()

    def build_conv_internal(self, conv, **kwargs):
        """
        Build the input conv followed by ``num_convs`` internal 3x3 convs.
        """
        _conv_internal = nn.Sequential()
        _conv_internal.add_module(
            name="c_in",
            module=conv(
                self.in_channels,
                self.internal_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                **kwargs,
            ))
        for i in range(self.num_convs):
            _conv_internal.add_module(
                name=f"c_internal{i}",
                module=conv(
                    self.internal_channels,
                    self.internal_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    **kwargs,
                ))
        return _conv_internal

    def build_conv_out(self, conv):
        """
        Build the final conv predicting class logits per anchor
        (no norm/activation, with bias for the prior-prob init).
        """
        out_channels = self.num_classes * self.anchors_per_pos
        return conv(
            self.internal_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            add_norm=False,
            add_act=False,
            bias=True,
        )

    def forward(self,
                x: torch.Tensor,
                level: int,
                **kwargs,
                ) -> torch.Tensor:
        """
        Forward input

        Args:
            x (torch.Tensor): input feature map of size (N x C x Y x X x Z)
            level: decoder level (unused here; kept for head interface)

        Returns:
            torch.Tensor: classification logits for each anchor
                (N x anchors x num_classes)
        """
        class_logits = self.conv_out(self.conv_internal(x))
        # move the channel axis last, then flatten spatial positions x
        # anchors into a single anchor axis
        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
        class_logits = class_logits.permute(*axes)
        class_logits = class_logits.contiguous()
        class_logits = class_logits.view(x.size()[0], -1, self.num_classes)
        return class_logits

    def compute_loss(self, pred_logits: Tensor, targets: Tensor, **kwargs) -> Tensor:
        """
        Compute the classification loss (in general, hard negative example
        mining should be done before this)

        Args:
            pred_logits (Tensor): predicted logits
            targets (Tensor): classification targets

        Returns:
            Tensor: classification loss
        """
        return self.loss(pred_logits, targets.long(), **kwargs)

    def box_logits_to_probs(self, box_logits: Tensor) -> Tensor:
        """
        Convert bounding box logits to probabilities

        Args:
            box_logits (Tensor): bounding box logits [N, C]
                N = number of anchors, C = number of foreground classes

        Returns:
            Tensor: probabilities
        """
        return self.logits_convert_fn(box_logits)

    def init_weights(self) -> None:
        """
        Initialize conv weights; when ``self.prior_prob`` is set, the final
        conv biases are derived from the prior probability.
        """
        if self.prior_prob is not None:
            logger.info(f"Init classifier weights: prior prob {self.prior_prob}")
            for layer in self.modules():
                if isinstance(layer, CONV_TYPES):
                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                    if layer.bias is not None:
                        torch.nn.init.constant_(layer.bias, 0)
            # Use prior in model initialization to improve stability:
            # bias chosen so that sigmoid(bias) == prior_prob
            bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
            for layer in self.conv_out.modules():
                if isinstance(layer, CONV_TYPES):
                    torch.nn.init.constant_(layer.bias, bias_value)
        else:
            logger.info("Init classifier weights: conv default")
class BCECLassifier(BaseClassifier):
    """
    Classification head trained with a one-hot sigmoid BCE loss and
    prior-probability weight initialization.

    Structure: conv(in, internal) -> num_convs x conv(internal, internal)
    -> conv(internal, out)
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 num_classes: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 prior_prob: Optional[float] = None,
                 weight: Optional[Tensor] = None,
                 reduction: str = "mean",
                 smoothing: float = 0.0,
                 loss_weight: float = 1.,
                 **kwargs
                 ):
        """
        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            num_classes: number of foreground classes
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the head
            num_convs: number of internal convolutions
            add_norm: en-/disable normalization layers in internal layers
            prior_prob: initialize final conv bias from this prior probability
            weight: class weights for BCEWithLogitsLoss (see pytorch docs)
            reduction: loss reduction. 'sum' | 'mean' | 'none'
            smoothing: label smoothing
            loss_weight: scalar to balance multiple losses
            kwargs: forwarded to the first and internal convolutions
        """
        # must be assigned before super().__init__, which calls init_weights()
        self.prior_prob = prior_prob
        base_kwargs = dict(
            conv=conv,
            in_channels=in_channels,
            internal_channels=internal_channels,
            num_classes=num_classes,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            num_convs=num_convs,
            add_norm=add_norm,
        )
        super().__init__(**base_kwargs, **kwargs)
        self.loss = BCEWithLogitsLossOneHot(
            num_classes=num_classes,
            weight=weight,
            reduction=reduction,
            smoothing=smoothing,
            loss_weight=loss_weight,
        )
        self.logits_convert_fn = nn.Sigmoid()
class CEClassifier(BaseClassifier):
    """
    Classification head trained with softmax cross-entropy (class 0 is
    treated as background and dropped when converting to probabilities).

    Structure: conv(in, internal) -> num_convs x conv(internal, internal)
    -> conv(internal, out)
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 num_classes: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 prior_prob: Optional[float] = None,
                 weight: Optional[Tensor] = None,
                 reduction: str = "mean",
                 loss_weight: float = 1.,
                 **kwargs
                 ):
        """
        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            num_classes: number of foreground classes
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the head
            num_convs: number of internal convolutions
            add_norm: en-/disable normalization layers in internal layers
            prior_prob: initialize final conv bias from this prior probability
            weight: class weights for the cross entropy loss (see pytorch docs)
            reduction: loss reduction. 'sum' | 'mean' | 'none'
            loss_weight: scalar to balance multiple losses
            kwargs: forwarded to the first and internal convolutions
        """
        # must be assigned before super().__init__, which calls init_weights()
        self.prior_prob = prior_prob
        base_kwargs = dict(
            conv=conv,
            in_channels=in_channels,
            internal_channels=internal_channels,
            num_classes=num_classes,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            num_convs=num_convs,
            add_norm=add_norm,
        )
        super().__init__(**base_kwargs, **kwargs)
        self.loss = CrossEntropyLoss(
            weight=weight,
            reduction=reduction,
            loss_weight=loss_weight,
        )
        self.logits_convert_fn = nn.Softmax(dim=1)

    def box_logits_to_probs(self, box_logits: Tensor) -> Tensor:
        """
        Softmax over classes, dropping the background column (index 0).

        Args:
            box_logits (Tensor): bounding box logits [N, C],
                C = number of classes

        Returns:
            Tensor: foreground probabilities [N, C - 1]
        """
        probs = self.logits_convert_fn(box_logits)
        return probs[:, 1:]
class FocalClassifier(BaseClassifier):
    """
    Classification head trained with a sigmoid focal loss and
    prior-probability weight initialization.

    Structure: conv(in, internal) -> num_convs x conv(internal, internal)
    -> conv(internal, out)
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 num_classes: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 prior_prob: Optional[float] = None,
                 gamma: float = 2,
                 alpha: float = -1,
                 reduction: str = "sum",
                 loss_weight: float = 1.,
                 **kwargs
                 ):
        """
        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            num_classes: number of foreground classes
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the head
            num_convs: number of internal convolutions
            add_norm: en-/disable normalization layers in internal layers
            prior_prob: initialize final conv bias from this prior probability
            gamma: focal loss gamma
            alpha: focal loss alpha
            reduction: loss reduction. 'sum' | 'mean' | 'none'
            loss_weight: scalar to balance multiple losses
            kwargs: forwarded to the first and internal convolutions
        """
        # must be assigned before super().__init__, which calls init_weights()
        self.prior_prob = prior_prob
        base_kwargs = dict(
            conv=conv,
            in_channels=in_channels,
            internal_channels=internal_channels,
            num_classes=num_classes,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            num_convs=num_convs,
            add_norm=add_norm,
        )
        super().__init__(**base_kwargs, **kwargs)
        self.loss = FocalLossWithLogits(
            gamma=gamma,
            alpha=alpha,
            reduction=reduction,
            loss_weight=loss_weight,
        )
        self.logits_convert_fn = nn.Sigmoid()
class AsymmetricFocalClassifier(FocalClassifier):
    """
    Classification head trained with an *asymmetric* sigmoid focal loss
    and prior-probability weight initialization.

    Structure: conv(in, internal) -> num_convs x conv(internal, internal)
    -> conv(internal, out)
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 num_classes: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 prior_prob: Optional[float] = None,
                 gamma: float = 2,
                 alpha: float = -1,
                 reduction: str = "sum",
                 loss_weight: float = 1.,
                 **kwargs
                 ):
        """
        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            num_classes: number of foreground classes
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the head
            num_convs: number of internal convolutions
            add_norm: en-/disable normalization layers in internal layers
            prior_prob: initialize final conv bias from this prior probability
            gamma: focal loss gamma
            alpha: focal loss alpha
            reduction: loss reduction. 'sum' | 'mean' | 'none'
            loss_weight: scalar to balance multiple losses
            kwargs: forwarded to the first and internal convolutions
        """
        # must be assigned before super().__init__, which calls init_weights()
        self.prior_prob = prior_prob
        super().__init__(
            conv=conv,
            in_channels=in_channels,
            internal_channels=internal_channels,
            num_classes=num_classes,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            num_convs=num_convs,
            add_norm=add_norm,
            gamma=gamma,
            alpha=alpha,
            reduction=reduction,
            loss_weight=loss_weight,
            **kwargs,
        )
        # replace the symmetric focal loss installed by the parent class
        self.loss = AsymmetricFocalLossWithLogits(
            gamma=gamma,
            alpha=alpha,
            reduction=reduction,
            loss_weight=loss_weight,
        )
        self.logits_convert_fn = nn.Sigmoid()
class FullyConntectedBCECLassifier(BCECLassifier):
    """
    BCE classifier whose internal and final convs use 1x1 kernels, acting
    as fully connected layers with weights shared across spatial positions.

    conv3(in, internal) -> num_convs x conv1(internal, internal)
    -> conv1(internal, out)
    """

    def build_conv_internal(self, conv, **kwargs):
        """
        Build the 3x3 input conv followed by ``num_convs`` 1x1 convs.
        """
        layers = nn.Sequential()
        layers.add_module(
            name="c_in",
            module=conv(
                self.in_channels,
                self.internal_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                **kwargs,
            ))
        for idx in range(self.num_convs):
            layers.add_module(
                name=f"c_internal{idx}",
                module=conv(
                    self.internal_channels,
                    self.internal_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    **kwargs,
                ))
        return layers

    def build_conv_out(self, conv):
        """
        Build the final 1x1 conv predicting class logits per anchor.
        """
        return conv(
            self.internal_channels,
            self.num_classes * self.anchors_per_pos,
            kernel_size=1,
            stride=1,
            padding=0,
            add_norm=False,
            add_act=False,
            bias=True,
        )
ClassifierType = TypeVar('ClassifierType', bound=Classifier)
# NOTE: "This diff is collapsed." is an artifact of the diff view this file was copied from; a file's contents are missing here.
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple, Callable, TypeVar
from abc import abstractmethod
from loguru import logger
from nndet.detection.boxes import box_iou
from nndet.models.layers.scale import Scale
from torch import Tensor
from nndet.losses import SmoothL1Loss, GIoULoss
CONV_TYPES = (nn.Conv2d, nn.Conv3d)
class Regressor(nn.Module):
    """
    Abstract interface for bounding-box regression heads.

    Subclasses implement :meth:`compute_loss`.
    """

    @abstractmethod
    def compute_loss(self, pred_deltas: Tensor, target_deltas: Tensor, **kwargs) -> Tensor:
        """
        Compute the regression loss

        Args:
            pred_deltas (Tensor): predicted bounding box deltas [N, dim * 2]
            target_deltas (Tensor): target bounding box deltas [N, dim * 2]

        Returns:
            Tensor: loss
        """
        raise NotImplementedError
class BaseRegressor(Regressor):
    """
    Base class for regressor heads with the conv structure
    conv(in, internal) -> num_convs x conv(internal, internal)
    -> conv(internal, out).
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 learn_scale: bool = False,
                 **kwargs,
                 ):
        """
        Base class to build regressor heads with typical conv structure

        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the
                regressor
            num_convs: number of internal convolutions
                (in conv -> num convs -> final conv)
            add_norm: en-/disable normalization layers in internal layers
            learn_scale: learn an additional single scalar value per
                feature pyramid level
            kwargs: keyword arguments passed to first and internal
                convolutions

        Notes:
            ``self.loss`` needs to be overwritten in subclasses
        """
        super().__init__()
        self.dim = conv.dim
        self.num_levels = num_levels
        self.num_convs = num_convs
        self.learn_scale = learn_scale
        self.anchors_per_pos = anchors_per_pos
        self.in_channels = in_channels
        self.internal_channels = internal_channels
        self.conv_internal = self.build_conv_internal(conv, add_norm=add_norm, **kwargs)
        self.conv_out = self.build_conv_out(conv)
        if self.learn_scale:
            self.scales = self.build_scales()
        # must be assigned by subclasses before use
        self.loss: Optional[nn.Module] = None
        self.init_weights()

    def build_conv_internal(self, conv, **kwargs):
        """
        Build the input conv followed by ``num_convs`` internal 3x3 convs.
        """
        _conv_internal = nn.Sequential()
        _conv_internal.add_module(
            name="c_in",
            module=conv(
                self.in_channels,
                self.internal_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                **kwargs,
            ))
        for i in range(self.num_convs):
            _conv_internal.add_module(
                name=f"c_internal{i}",
                module=conv(
                    self.internal_channels,
                    self.internal_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    **kwargs,
                ))
        return _conv_internal

    def build_conv_out(self, conv):
        """
        Build the final conv predicting one delta vector (2 * dim values)
        per anchor (no norm/activation).
        """
        out_channels = self.anchors_per_pos * self.dim * 2
        return conv(
            self.internal_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            add_norm=False,
            add_act=False,
            bias=True,
        )

    def build_scales(self) -> nn.ModuleList:
        """
        Build one learnable scalar per feature pyramid level
        """
        logger.info("Learning level specific scalar in regressor")
        return nn.ModuleList([Scale() for _ in range(self.num_levels)])

    def forward(self, x: torch.Tensor, level: int, **kwargs) -> torch.Tensor:
        """
        Forward input

        Args:
            x: input feature map of size [N x C x Y x X x Z]
            level: feature pyramid level; selects the learned scalar when
                ``learn_scale`` is enabled

        Returns:
            torch.Tensor: bounding box deltas for each anchor
                [N, n_anchors, dim * 2]
        """
        bb_logits = self.conv_out(self.conv_internal(x))
        if self.learn_scale:
            bb_logits = self.scales[level](bb_logits)
        # move channel axis last, then flatten spatial positions x anchors
        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
        bb_logits = bb_logits.permute(*axes)
        bb_logits = bb_logits.contiguous()
        bb_logits = bb_logits.view(x.size()[0], -1, self.dim * 2)
        return bb_logits

    def compute_loss(self,
                     pred_deltas: Tensor,
                     target_deltas: Tensor,
                     **kwargs,
                     ) -> Tensor:
        """
        Compute the regression loss

        Args:
            pred_deltas: predicted bounding box deltas [N, dim * 2]
            target_deltas: target bounding box deltas [N, dim * 2]

        Returns:
            Tensor: loss
        """
        return self.loss(pred_deltas, target_deltas, **kwargs)

    def init_weights(self) -> None:
        """
        Init conv weights with a normal distribution (mean=0, std=0.01)
        and zero biases
        """
        logger.info("Overwriting regressor conv weight init")
        for layer in self.modules():
            if isinstance(layer, CONV_TYPES):
                torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)
class L1Regressor(BaseRegressor):
    """
    Regression head trained with a smooth L1 loss.

    Structure: conv(in, internal) -> num_convs x conv(internal, internal)
    -> conv(internal, out)
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 beta: float = 1.,
                 reduction: Optional[str] = "sum",
                 loss_weight: float = 1.,
                 learn_scale: bool = False,
                 **kwargs,
                 ):
        """
        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the
                regressor
            num_convs: number of internal convolutions
            add_norm: en-/disable normalization layers in internal layers
            beta: L1-to-L2 change point; for beta values < 1e-5 a plain
                L1 loss is computed
            reduction: loss reduction. 'sum' | 'mean' | 'none'
            loss_weight: scalar to balance multiple losses
            learn_scale: learn an additional single scalar value per
                feature pyramid level
            kwargs: forwarded to the first and internal convolutions
        """
        base_kwargs = dict(
            conv=conv,
            in_channels=in_channels,
            internal_channels=internal_channels,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            num_convs=num_convs,
            add_norm=add_norm,
            learn_scale=learn_scale,
        )
        super().__init__(**base_kwargs, **kwargs)
        self.loss = SmoothL1Loss(
            beta=beta,
            reduction=reduction,
            loss_weight=loss_weight,
        )
class GIoURegressor(BaseRegressor):
    """
    Regression head trained with a generalized IoU loss.

    Structure: conv(in, internal) -> num_convs x conv(internal, internal)
    -> conv(internal, out)
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 reduction: Optional[str] = "sum",
                 loss_weight: float = 1.,
                 learn_scale: bool = False,
                 **kwargs,
                 ):
        """
        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the
                regressor
            num_convs: number of internal convolutions
            add_norm: en-/disable normalization layers in internal layers
            reduction: loss reduction. 'sum' | 'mean' | 'none'
            loss_weight: scalar to balance multiple losses
            learn_scale: learn an additional single scalar value per
                feature pyramid level
            kwargs: forwarded to the first and internal convolutions
        """
        base_kwargs = dict(
            conv=conv,
            in_channels=in_channels,
            internal_channels=internal_channels,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            num_convs=num_convs,
            add_norm=add_norm,
            learn_scale=learn_scale,
        )
        super().__init__(**base_kwargs, **kwargs)
        self.loss = GIoULoss(
            reduction=reduction,
            loss_weight=loss_weight,
        )
class IoUBranchGIoURegressor(GIoURegressor):
    """
    GIoU box regression head with an additional branch predicting the IoU
    of each predicted box with its target.
    """

    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 learn_scale: bool = False,
                 reduction: Optional[str] = "sum",
                 loss_weight: float = 1.,
                 loss_weight_iou_branch: float = 1.,
                 iou_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
                 **kwargs,
                 ):
        """
        GIoU Box regression head with additional IoU prediction branch

        Args:
            conv: convolution module generator handling a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels passed through the
                regressor
            num_convs: number of internal convolutions
            add_norm: en-/disable normalization layers in internal layers
            learn_scale: learn an additional single scalar value per
                feature pyramid level
            reduction: loss reduction. 'sum' | 'mean' | 'none'
            loss_weight: scalar to balance multiple losses
            loss_weight_iou_branch: weight of the IoU branch loss
            iou_fn: IoU function to compute targets for the IoU branch
            kwargs: forwarded to the first and internal convolutions
        """
        super().__init__(
            conv=conv,
            in_channels=in_channels,
            internal_channels=internal_channels,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            num_convs=num_convs,
            add_norm=add_norm,
            learn_scale=learn_scale,
            reduction=reduction,
            loss_weight=loss_weight,
            **kwargs
        )
        self.conv_iou_branch = self.build_conv_iou_branch(conv)
        # IoU targets are in [0, 1], so a BCE-with-logits objective is used
        self.iou_branch_loss = nn.BCEWithLogitsLoss()
        self.loss_weight_iou_branch = loss_weight_iou_branch
        self.iou_fn = iou_fn

    def build_conv_iou_branch(self, conv) -> nn.Module:
        """
        Build the conv predicting one IoU logit per anchor
        (no norm/activation).
        """
        return conv(
            self.internal_channels,
            self.anchors_per_pos,
            kernel_size=3,
            stride=1,
            padding=1,
            add_norm=False,
            add_act=False,
            bias=True,
        )

    def forward(self, x: torch.Tensor, level: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward input

        Args:
            x (torch.Tensor): input feature map of size [N x C x Y x X x Z]
            level: feature pyramid level; selects the learned scalar when
                ``learn_scale`` is enabled

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                bounding box deltas per anchor [N, n_anchors, dim * 2]
                and IoU logits per anchor [N, n_anchors]
        """
        # both output branches share the internal feature stack
        intermediate_features = self.conv_internal(x)
        bb_logits = self.conv_out(intermediate_features)
        iou_logits = self.conv_iou_branch(intermediate_features)
        if self.learn_scale:
            bb_logits = self.scales[level](bb_logits)
        # move channel axis last, then flatten spatial positions x anchors
        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
        bb_logits = bb_logits.permute(*axes).contiguous()
        bb_logits = bb_logits.view(x.size()[0], -1, self.dim * 2)
        iou_logits = iou_logits.permute(*axes).contiguous()
        iou_logits = iou_logits.view(x.size()[0], -1)
        return bb_logits, iou_logits

    def compute_loss(self,
                     pred_boxes: Tensor,
                     target_boxes: Tensor,
                     pred_iou: Tensor,
                     ) -> Tensor:
        """
        Compute the regression loss plus the IoU branch loss

        NOTE(review): the parameters are named/documented as deltas in the
        base class but are fed to ``iou_fn`` (default ``box_iou``) directly
        here — confirm callers pass decoded boxes, not deltas.

        Args:
            pred_boxes: predicted boxes [N, dim * 2]
            target_boxes: target boxes [N, dim * 2]
            pred_iou: predicted IoU logits [N]

        Returns:
            Tensor: combined loss
        """
        reg_loss = self.loss(pred_boxes, target_boxes)
        # pairwise IoU of all N x N box pairs; only the diagonal
        # (each prediction vs its own target) is used.
        # NOTE(review): this is O(N^2) in memory/compute for N boxes.
        target_ious = self.iou_fn(pred_boxes, target_boxes).diag(diagonal=0)
        iou_branch_loss = self.loss_weight_iou_branch * self.iou_branch_loss(pred_iou, target_ious)
        return reg_loss + iou_branch_loss
# Type variable bound to Regressor: annotates APIs that accept or return
# any concrete regressor head implementation.
RegressorType = TypeVar('RegressorType', bound=Regressor)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from torch import Tensor
from typing import Dict, List, Union, Sequence, Optional, Tuple, TypeVar
from nndet.models.conv import compute_padding_for_kernel, conv_kwargs_helper
from nndet.models.heads.comb import AbstractHead
from nndet.losses.segmentation import SoftDiceLoss, TopKLoss
from nndet.models.layers.interpolation import InterpolateToShapes
class Segmenter(AbstractHead):
    """
    Abstract interface for segmentation heads.

    Only stores the configuration shared by all segmentation heads;
    concrete subclasses implement the forward pass and the loss.

    Args:
        seg_classes: number of foreground classes
            (one background class is added internally, i.e.
            ``self.seg_classes == seg_classes + 1``)
        in_channels: number of input channels at all decoder levels
        decoder_levels: decoder levels used for detection
    """

    def __init__(self,
                 seg_classes: int,
                 in_channels: Sequence[int],
                 decoder_levels: Sequence[int],
                 **kwargs,
                 ):
        super().__init__()
        self.decoder_levels = decoder_levels
        self.in_channels = in_channels
        # +1 accounts for the implicit background class
        self.seg_classes = seg_classes + 1
class DiCESegmenter(Segmenter):
    def __init__(self,
                 conv,
                 seg_classes: int,
                 in_channels: Sequence[int],
                 decoder_levels: Sequence[int],
                 internal_channels: Optional[int] = None,
                 num_internal: int = 0,
                 add_norm: bool = True,
                 add_act: bool = True,
                 kernel_size: Union[int, Sequence[int]] = 3,
                 alpha: float = 0.5,
                 ce_kwargs: Optional[dict] = None,
                 dice_kwargs: Optional[dict] = None,
                 **kwargs,
                 ):
        """
        Basic Segmentation Head with dice and CE loss
        (num_internal x conv [kernel_size]) -> final conv [1x1]

        Args:
            conv: Convolution modules which handles a single layer
            seg_classes: number of foreground classes
                (!! internally +1 added for background !!)
            in_channels: number of input channels at all decoder levels
            decoder_levels: decoder levels used for detection
            internal_channels: number of channels of internal convolutions;
                defaults to the channel count of the top decoder level
            num_internal: number of internal convolutions
            add_norm: add normalization layers to internal convolutions
            add_act: add activation layers to internal convolutions
            kernel_size: kernel size of internal convolutions
            alpha: weight of dice and ce loss (alpha * ce + (1 - alpha) * soft_dice)
            ce_kwargs: keyword arguments passed to CE loss
            dice_kwargs: keyword arguments passed to dice loss
        """
        super().__init__(
            seg_classes=seg_classes,
            in_channels=in_channels,
            decoder_levels=decoder_levels,
        )
        self.num_internal = num_internal
        if internal_channels is None:
            # fall back to the channel count of the top (largest) decoder level
            self.internal_channels = self.in_channels[0]
        else:
            self.internal_channels = internal_channels

        self.conv_out = self.build_conv_out(conv)
        self.conv_intermediate = self.build_conv_internal(
            conv,
            kernel_size=kernel_size,
            add_norm=add_norm,
            add_act=add_act,
            **kwargs,
        )

        # copy before setdefault: do not mutate the caller's kwargs dict
        dice_kwargs = {} if dice_kwargs is None else dict(dice_kwargs)
        dice_kwargs.setdefault("smooth_nom", 1e-5)
        dice_kwargs.setdefault("smooth_denom", 1e-5)
        dice_kwargs.setdefault("do_bg", False)
        self.dice_loss = SoftDiceLoss(nonlin=torch.nn.Softmax(dim=1), **dice_kwargs)

        if ce_kwargs is None:
            ce_kwargs = {}
        self.ce_loss = torch.nn.CrossEntropyLoss(**ce_kwargs)

        self.logits_convert_fn = nn.Softmax(dim=1)
        self.alpha = alpha

    def build_conv_out(self, conv) -> nn.Module:
        """
        Build the final 1x1 output convolution mapping to class logits
        """
        # without internal convs, conv_out consumes the top decoder level directly
        _intermediate_channels = self.internal_channels if self.num_internal > 0 else self.in_channels[0]
        return conv(
            _intermediate_channels,
            self.seg_classes,
            kernel_size=1,
            padding=0,
            add_norm=None,
            add_act=None,
            bias=True,
        )

    def build_conv_internal(self,
                            conv,
                            kernel_size: Union[int, Tuple[int]],
                            add_norm: bool,
                            add_act: bool,
                            **kwargs,
                            ) -> Optional[nn.Module]:
        """
        Build internal convolutions

        Returns:
            Optional[nn.Module]: sequential stack of ``num_internal``
                convolutions or None if ``num_internal == 0``
        """
        padding = compute_padding_for_kernel(kernel_size)
        if self.num_internal > 0:
            _intermediate = torch.nn.Sequential()
            for i in range(self.num_internal):
                _intermediate.add_module(
                    f"c_intermediate{i}",
                    conv(
                        # first conv consumes the top decoder level; previously
                        # the whole in_channels sequence was passed here, which
                        # is not a valid channel count
                        self.in_channels[0] if i == 0 else self.internal_channels,
                        self.internal_channels,
                        kernel_size=kernel_size,
                        padding=padding,
                        stride=1,
                        add_norm=add_norm,
                        add_act=add_act,
                        **kwargs
                    )
                )
        else:
            _intermediate = None
        return _intermediate

    def forward(self,
                x: List[torch.Tensor],
                ) -> Dict[str, torch.Tensor]:
        """
        Forward pass

        Args:
            x: all features produced by decoder. Largest to smallest.
                Only the first (largest) feature map is used.

        Returns:
            Dict[str, torch.Tensor]: predictions
                `seg_logits`: segmentation logits [N, seg_classes, dims]
        """
        x = x[0]
        if self.conv_intermediate is not None:
            x = self.conv_intermediate(x)
        return {"seg_logits": self.conv_out(x)}

    def compute_loss(self,
                     pred_seg: Dict[str, torch.Tensor],
                     target: torch.Tensor,
                     ) -> Dict[str, torch.Tensor]:
        """
        Compute weighted dice and cross entropy loss

        Args:
            pred_seg: segmentation predictions
                `seg_logits`: predicted logits
            target: ground truth segmentation of top layer

        Returns:
            Dict[str, torch.Tensor]: individual (already weighted) loss terms
                `seg_ce`: weighted cross entropy term
                `seg_dice`: weighted soft dice term
        """
        seg_logits = pred_seg["seg_logits"]
        return {
            "seg_ce": self.alpha * self.ce_loss(seg_logits, target.long()),
            "seg_dice": (1 - self.alpha) * self.dice_loss(seg_logits, target),
        }

    def postprocess_for_inference(self,
                                  prediction: Dict[str, torch.Tensor],
                                  *args, **kwargs,
                                  ) -> Dict[str, torch.Tensor]:
        """
        Postprocess predictions for inference e.g. convert logits to probs

        Args:
            prediction: predictions from this head
                `seg_logits`: predicted logits

        Returns:
            Dict[str, torch.Tensor]: postprocessed predictions
                `pred_seg`: predicted probabilities [N, C, dims]
        """
        return {"pred_seg": self.logits_convert_fn(prediction["seg_logits"])}
class DiCESegmenterFgBg(DiCESegmenter):
    def __init__(self,
                 conv,
                 seg_classes: int,
                 in_channels: Sequence[int],
                 decoder_levels: Sequence[int],
                 internal_channels: Optional[int] = None,
                 num_internal: int = 0,
                 add_norm: bool = True,
                 add_act: bool = True,
                 kernel_size: Union[int, Sequence[int]] = 3,
                 alpha: float = 0.5,
                 **kwargs,
                 ):
        """
        Basic Segmentation Head with dice and CE loss which only
        differentiates foreground and background
        (num_internal x conv [kernel_size]) -> final conv [1x1]

        Args:
            conv: Convolution modules which handles a single layer
            seg_classes: ignored! (the head always uses a single
                foreground class)
            in_channels: number of input channels at all decoder levels
            decoder_levels: decoder levels used for detection
            internal_channels: number of channels of internal convolutions
            num_internal: number of internal convolutions
            add_norm: add normalization layers to internal convolutions
            add_act: add activation layers to internal convolutions
            kernel_size: kernel size of internal convolutions
            alpha: weight of dice and ce loss (alpha * ce + (1 - alpha) * soft_dice)

        Warnings:
            If this class is used, the reported dice scores during training
            are wrong if multiple classes are present in the dataset.
        """
        super().__init__(conv=conv,
                         in_channels=in_channels,
                         seg_classes=1,  # force fg/bg regardless of seg_classes
                         decoder_levels=decoder_levels,
                         internal_channels=internal_channels,
                         num_internal=num_internal,
                         add_norm=add_norm,
                         add_act=add_act,
                         kernel_size=kernel_size,
                         alpha=alpha,
                         **kwargs,
                         )

    def compute_loss(self,
                     pred_seg: Dict[str, torch.Tensor],
                     target: torch.Tensor,
                     ) -> Dict[str, torch.Tensor]:
        """
        Compute weighted dice and cross entropy loss on a binarized target
        (every foreground class is mapped to 1)

        Args:
            pred_seg: segmentation predictions
                `seg_logits`: predicted logits
            target: ground truth segmentation of top layer (not modified)

        Returns:
            Dict[str, torch.Tensor]: computed loss terms (`seg_ce`, `seg_dice`)
        """
        # binarize a copy: mutating in place would silently corrupt the
        # caller's ground truth tensor
        target = target.clone()
        target[target > 0] = 1
        return super().compute_loss(pred_seg, target)
class DiceTopKSegmenter(DiCESegmenter):
    def __init__(self,
                 conv,
                 seg_classes: int,
                 in_channels: Sequence[int],
                 decoder_levels: Sequence[int],
                 internal_channels: Optional[int] = None,
                 num_internal: int = 0,
                 add_norm: bool = True,
                 add_act: bool= True,
                 kernel_size: Union[int, Sequence[int]] = 3,
                 alpha: float = 0.5,
                 topk: float = 0.1,
                 **kwargs,
                 ):
        """
        Segmentation head combining a soft dice loss with a TopK cross
        entropy loss (only the hardest ``topk`` fraction of voxels
        contributes to the CE term)
        (num_internal x conv [kernel_size]) -> final conv [1x1]

        Args:
            conv: Convolution modules which handles a single layer
            seg_classes: number of foreground classes
                (!! internally +1 added for background !!)
            in_channels: number of input channels at all decoder levels
            decoder_levels: decoder levels used for detection
            internal_channels: number of channels of internal convolutions
            num_internal: number of internal convolutions
            add_norm: add normalization layers to internal convolutions
            add_act: add activation layers to internal convolutions
            kernel_size: kernel size of internal convolutions
            alpha: weight of dice and ce loss (alpha * ce + (1 - alpha) * soft_dice)
            topk: percentage of all entries to use for loss computation
        """
        super().__init__(
            conv=conv,
            seg_classes=seg_classes,
            in_channels=in_channels,
            decoder_levels=decoder_levels,
            internal_channels=internal_channels,
            num_internal=num_internal,
            add_norm=add_norm,
            add_act=add_act,
            kernel_size=kernel_size,
            alpha=alpha,
            ce_kwargs=None,
            **kwargs,
        )
        # swap the parent's plain cross entropy for its TopK variant
        self.ce_loss = TopKLoss(topk=topk)
class DiceTopKSegmenterFgBg(DiCESegmenterFgBg):
    def __init__(self,
                 conv,
                 seg_classes: int,
                 in_channels: Sequence[int],
                 decoder_levels: Sequence[int],
                 internal_channels: Optional[int] = None,
                 num_internal: int = 0,
                 add_norm: bool = True,
                 add_act: bool= True,
                 kernel_size: Union[int, Sequence[int]] = 3,
                 alpha: float = 0.5,
                 topk: float = 0.1,
                 **kwargs,
                 ):
        """
        Foreground/background segmentation head combining a soft dice loss
        with a TopK cross entropy loss (only the hardest ``topk`` fraction
        of voxels contributes to the CE term)
        (num_internal x conv [kernel_size]) -> final conv [1x1]

        Args:
            conv: Convolution modules which handles a single layer
            seg_classes: ignored! (the parent forces a single fg class)
            in_channels: number of input channels at all decoder levels
            decoder_levels: decoder levels used for detection
            internal_channels: number of channels of internal convolutions
            num_internal: number of internal convolutions
            add_norm: add normalization layers to internal convolutions
            add_act: add activation layers to internal convolutions
            kernel_size: kernel size of internal convolutions
            alpha: weight of dice and ce loss (alpha * ce + (1 - alpha) * soft_dice)
            topk: percentage of all entries to use for loss computation

        Warnings:
            If this class is used, the reported dice scores during training
            are wrong if multiple classes are present in the dataset.
        """
        super().__init__(
            conv=conv,
            seg_classes=seg_classes,
            in_channels=in_channels,
            decoder_levels=decoder_levels,
            internal_channels=internal_channels,
            num_internal=num_internal,
            add_norm=add_norm,
            add_act=add_act,
            kernel_size=kernel_size,
            alpha=alpha,
            **kwargs,
        )
        # swap the parent's plain cross entropy for its TopK variant
        self.ce_loss = TopKLoss(topk=topk)
class DeepSupervisionSegmenterFGBG(DiCESegmenterFgBg):
    def __init__(self,
                 conv,
                 seg_classes: int,
                 in_channels: Sequence[int],
                 decoder_levels: Sequence[int],
                 internal_channels: Optional[int] = None,
                 num_internal: int = 0,
                 add_norm: bool = True,
                 add_act: bool = True,
                 kernel_size: Union[int, Sequence[int]] = 3,
                 alpha: float = 0.5,
                 dsv_weight: float = 1.,
                 **kwargs,
                 ):
        """
        Deep supervision segmentation head which trains with CE and Dice
        to differentiate foreground and background
        (num_internal x conv [kernel_size]) -> final conv [1x1]

        Args:
            conv: Convolution modules which handles a single layer
            seg_classes: ignored!
                (!! internally +1 added for background !!)
            in_channels: number of input channels at all decoder levels
            decoder_levels: decoder levels used for detection
            internal_channels: number of channels of internal convolutions
            num_internal: number of internal convolutions
            add_norm: add normalization layers to internal convolutions
            add_act: add activation layers to internal convolutions
            kernel_size: kernel size of internal convolutions
            alpha: weight of dice and ce loss (alpha * ce + (1 - alpha) * soft_dice)
            dsv_weight: additional weight for dsv losses
        """
        super().__init__(conv=conv,
                         in_channels=in_channels,
                         seg_classes=1,
                         decoder_levels=decoder_levels,
                         internal_channels=internal_channels,
                         num_internal=num_internal,
                         add_norm=add_norm,
                         add_act=add_act,
                         kernel_size=kernel_size,
                         alpha=alpha,
                         **kwargs,
                         )
        assert len(self.decoder_levels) > 0
        # single auxiliary conv shared by all supervised decoder levels;
        # 2 output channels = foreground/background
        # NOTE(review): assumes all levels in decoder_levels carry
        # in_channels[-1] channels — confirm against the decoder config
        self.dsv_conv = conv(self.in_channels[-1],
                             2,
                             kernel_size=3,
                             padding=1,
                             add_norm=False,
                             add_act=False,
                             bias=True,
                             )
        # downsamples the full-resolution target to each dsv prediction shape
        self.interpolator = InterpolateToShapes()
        self.dsv_weight = dsv_weight

    def forward(self,
                x: List[torch.Tensor],
                ) -> Dict[str, torch.Tensor]:
        """
        Forward pass

        Args:
            x: all features produced by decoder. Largest to smallest.

        Returns:
            Dict[str, torch.Tensor]: predictions
                `seg_logits`: full resolution logits from the top level
                `dsv_logits_{dl}`: auxiliary logits for each decoder level
        """
        predictions = {}
        # bugfix: the attribute created in DiCESegmenter.__init__ is
        # `conv_intermediate`; `self.intermediate` does not exist
        if self.conv_intermediate is not None:
            predictions["seg_logits"] = self.conv_out(self.conv_intermediate(x[0]))
        else:
            predictions["seg_logits"] = self.conv_out(x[0])
        for dl in self.decoder_levels:
            predictions[f"dsv_logits_{dl}"] = self.dsv_conv(x[dl])
        return predictions

    def compute_loss(self,
                     pred_seg: Dict[str, torch.Tensor],
                     target: torch.Tensor,
                     ) -> Dict[str, torch.Tensor]:
        """
        Compute weighted dice and cross entropy loss over the main output
        and all deep supervision outputs

        Args:
            pred_seg: segmentation predictions
                `seg_logits`: predicted logits
                `dsv_logits_{dl}`: deep supervision logits per decoder level
            target: ground truth segmentation of top layer (not modified)

        Returns:
            Dict[str, torch.Tensor]: averaged combined loss (key `seg_loss`)
        """
        # binarize a copy instead of mutating the caller's tensor
        target = target.clone()
        target[target > 0] = 1
        loss = self._compute_loss(pred_seg["seg_logits"], target)
        dsv_preds = [pred_seg[f"dsv_logits_{dl}"] for dl in self.decoder_levels]
        dsv_targets = self.interpolator(dsv_preds, target)
        for dsv_pred, dsv_target in zip(dsv_preds, dsv_targets):
            loss = loss + self.dsv_weight * self._compute_loss(dsv_pred, dsv_target)
        # average over the main output plus all dsv outputs
        return {"seg_loss": loss / (len(self.decoder_levels) + 1)}

    def _compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Weighted CE + soft dice for a single prediction/target pair."""
        return self.alpha * self.ce_loss(pred, target.long()) + \
            (1 - self.alpha) * self.dice_loss(pred, target)
# Type variable bound to Segmenter: annotates APIs that accept or return
# any concrete segmentation head implementation.
SegmenterType = TypeVar('SegmenterType', bound=Segmenter)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment