Commit 4560ce77 authored by mibaumgartner's avatar mibaumgartner
Browse files

models

parent 94d6ac20
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from torch import nn
class InitWeights_He(object):
    def __init__(self,
                 neg_slope: float = 1e-2,
                 mode: str = "fan_in",
                 nonlinearity: str = "leaky_relu",
                 ):
        """
        Init weights according to https://arxiv.org/abs/1502.01852
        Args:
            neg_slope (float, optional): the negative slope of the rectifier
                used after this layer (only with 'leaky_relu').
                Defaults to 1e-2.
            mode: mode of `kaiming_normal_`, either 'fan_in' or 'fan_out'
            nonlinearity: name of non linear function. Recommended only with
                relu and leaky relu
        """
        self.neg_slope = neg_slope
        # bug fix: `mode` and `nonlinearity` were accepted but never stored,
        # so kaiming_normal_ silently fell back to its defaults
        self.mode = mode
        self.nonlinearity = nonlinearity

    def __call__(self, module: nn.Module):
        """
        Apply weight init in place (only initializes convolution weights)
        Args:
            module: module to initialize weights of
        """
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            # nn.init.* functions operate in place; re-assigning the result
            # to module.weight is unnecessary and can break Parameter
            # registration on some torch versions
            nn.init.kaiming_normal_(module.weight, a=self.neg_slope,
                                    mode=self.mode,
                                    nonlinearity=self.nonlinearity)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
from nndet.models.layers.interpolation import (
Interpolate,
InterpolateToShapes,
InterpolateToShape,
MaxPoolToShapes,
)
from nndet.models.layers.norm import GroupNorm
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn.functional as F
from typing import Union, Tuple, List
from torch import Tensor
__all__ = ["InterpolateToShapes", "InterpolateToShape", "Interpolate"]
class InterpolateToShapes(torch.nn.Module):
    def __init__(self, mode: str = "nearest", align_corners: bool = None):
        """
        Downsample a target tensor to the spatial size of each prediction
        feature map
        Args:
            mode: algorithm used for upsampling: nearest, linear, bilinear,
                bicubic, trilinear, area. Defaults to "nearest".
            align_corners: Align corners points for interpolation. (see pytorch
                for more info) Defaults to None.
        See Also:
            :func:`torch.nn.functional.interpolate`
        Warnings:
            Use nearest for segmentation, everything else will result in
            wrong values.
        """
        super().__init__()
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, preds: List[Tensor], target: Tensor) -> List[Tensor]:
        """
        Interpolate target to match shape with predictions
        Args:
            preds: predictions to extract spatial shapes from
            target: target to interpolate
        Returns:
            List[Tensor]: one interpolated target per prediction
        """
        # insert a channel axis when the target comes without one
        # (e.g. segmentation maps of shape [N, *spatial])
        added_channel_dim = target.ndim == preds[0].ndim - 1
        if added_channel_dim:
            target = target[:, None]

        resized = []
        for pred in preds:
            spatial = tuple(pred.shape[2:])
            resized.append(
                F.interpolate(target, size=spatial, mode=self.mode,
                              align_corners=self.align_corners))

        if added_channel_dim:
            resized = [r.squeeze(1) for r in resized]
        return resized
class MaxPoolToShapes(torch.nn.Module):
    def forward(self, preds: List[Tensor], target: Tensor) -> List[Tensor]:
        """
        Max-pool target down to the spatial size of each prediction
        Args:
            preds: predictions to extract shape of
            target: target to pool
        Returns:
            List[Tensor]: pooled targets
        """
        spatial_dims = preds[0].ndim - 2
        target_spatial = list(target.shape)[-spatial_dims:]
        # per-prediction pooling factors (assumes integer size ratios)
        factors = [
            tuple(int(t / p) for t, p in
                  zip(target_spatial, list(pred.shape)[-spatial_dims:]))
            for pred in preds
        ]

        # insert a channel axis when the target comes without one
        added_channel_dim = target.ndim == preds[0].ndim - 1
        if added_channel_dim:
            target = target[:, None]

        pool_fn = getattr(F, f"max_pool{spatial_dims}d")
        pooled = [pool_fn(target, kernel_size=f, stride=f) for f in factors]
        if added_channel_dim:
            pooled = [p.squeeze(1) for p in pooled]
        return pooled
class InterpolateToShape(InterpolateToShapes):
    """
    Interpolate predictions to the spatial size of the target
    """
    def forward(self, preds: List[Tensor], target: Tensor) -> List[Tensor]:
        """
        Interpolate predictions to match target
        Args:
            preds: predictions to interpolate
            target: target to extract the spatial shape from
        Returns:
            List[Tensor]: interpolated predictions
        """
        squeeze_result = False
        if target.ndim == preds[0].ndim - 1:
            target = target.unsqueeze(dim=1)
            squeeze_result = True
        # bug fix: the spatial shape must be read AFTER the potential
        # unsqueeze; previously a channel-less target dropped one spatial
        # dimension from `shape`
        shape = tuple(target.shape)[2:]
        new_targets = [F.interpolate(
            pred, size=shape, mode=self.mode, align_corners=self.align_corners)
            for pred in preds]
        if squeeze_result:
            new_targets = [nt.squeeze(dim=1) for nt in new_targets]
        return new_targets
class Interpolate(torch.nn.Module):
    def __init__(self, size: Union[int, Tuple[int]] = None,
                 scale_factor: Union[float, Tuple[float]] = None,
                 mode: str = "nearest", align_corners: bool = None):
        """
        Module wrapper around the functional interpolation from pytorch
        Args:
            size: output spatial size. Defaults to None.
            scale_factor: multiplier for spatial size. Has to match input size
                if it is a tuple. Defaults to None.
            mode: algorithm used for upsampling: nearest, linear, bilinear,
                bicubic, trilinear, area. Defaults to "nearest".
            align_corners: Align corners points for interpolation. (see pytorch
                for more info) Defaults to None.
        See Also:
            :func:`torch.nn.functional.interpolate`
        """
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x: Tensor) -> Tensor:
        """
        Interpolate the input batch
        Args:
            x: input tensor to interpolate
        Returns:
            Tensor: interpolated tensor
        """
        return F.interpolate(x,
                             size=self.size,
                             scale_factor=self.scale_factor,
                             mode=self.mode,
                             align_corners=self.align_corners)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch.nn as nn
from typing import Optional
"""
Note: register new normalization layers in
nndet.training.optimizer.NORM_TYPES to exclude them from weight decay
"""
class GroupNorm(nn.GroupNorm):
    def __init__(self, num_channels: int,
                 num_groups: Optional[int] = None,
                 channels_per_group: Optional[int] = None,
                 eps: float = 1e-05, affine: bool = True, **kwargs) -> None:
        """
        PyTorch Group Norm (changed interface, num_channels at first position)
        Args:
            num_channels: number of input channels
            num_groups: number of groups to separate channels. Mutually
                exclusive with `channels_per_group`
            channels_per_group: number of channels per group. Mutually exclusive
                with `num_groups`
            eps: value added to the denom for numerical stability. Defaults to 1e-05.
            affine: Enable learnable per channel affine params. Defaults to True.
        Raises:
            ValueError: if both or neither of `num_groups` and
                `channels_per_group` are given, or if `channels_per_group`
                does not evenly divide `num_channels`
        """
        if channels_per_group is not None:
            if num_groups is not None:
                raise ValueError("Can only use `channels_per_group` OR `num_groups` in GroupNorm")
            # guard against silently dropping channels via floor division
            if num_channels % channels_per_group != 0:
                raise ValueError(
                    f"`channels_per_group` ({channels_per_group}) must evenly "
                    f"divide `num_channels` ({num_channels})")
            num_groups = num_channels // channels_per_group
        elif num_groups is None:
            # fail with a clear message instead of passing None to nn.GroupNorm
            raise ValueError("Provide either `num_groups` or `channels_per_group` to GroupNorm")
        super().__init__(num_channels=num_channels,
                         num_groups=num_groups,
                         eps=eps, affine=affine, **kwargs)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
class Scale(nn.Module):
    def __init__(self, scale: float = 1.):
        """
        Learnable scalar multiplier for feature maps
        Args:
            scale: initial value of the scaling factor
        """
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Scale the input tensor
        Args:
            inp: input tensor
        Returns:
            Tensor: input multiplied by the learnable scale
        """
        return self.scale * inp

    def extra_repr(self) -> str:
        """Show the current scale value in the module repr."""
        return f"scale={self.scale.item()}"
@@ -26,7 +26,7 @@ from typing import Callable, Hashable, Sequence, Dict, Any, Type
import torch
import numpy as np
from loguru import logger
from torchvision.models.detection.rpn import AnchorGenerator
from nndet.utils.tensor import to_numpy
from nndet.evaluator.det import BoxEvaluator
@@ -43,7 +43,8 @@ from nndet.ptmodule.base_module import LightningBaseModuleSWA, LightningBaseModule
from nndet.models.conv import Generator, ConvInstanceRelu, ConvGroupRelu
from nndet.models.blocks.basic import StackedConvBlock2
from nndet.models.encoder.abstract import EncoderType
from nndet.models.encoder.modular import Encoder
from nndet.models.decoder.base import DecoderType, BaseUFPN, UFPNModular
from nndet.models.heads.classifier import ClassifierType, CEClassifier
from nndet.models.heads.regressor import RegressorType, L1Regressor
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment