import warnings
from typing import Callable, List, Optional

import torch
from torch import Tensor

from ..utils import _log_api_usage_once


class Conv2d(torch.nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torchvision.ops.misc.Conv2d is deprecated and will be "
            "removed in future versions, use torch.nn.Conv2d instead.",
            FutureWarning,
        )


class ConvTranspose2d(torch.nn.ConvTranspose2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torchvision.ops.misc.ConvTranspose2d is deprecated and will be "
            "removed in future versions, use torch.nn.ConvTranspose2d instead.",
            FutureWarning,
        )


class BatchNorm2d(torch.nn.BatchNorm2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torchvision.ops.misc.BatchNorm2d is deprecated and will be "
            "removed in future versions, use torch.nn.BatchNorm2d instead.",
            FutureWarning,
        )


interpolate = torch.nn.functional.interpolate


# This is not in nn
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
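
    A minimal usage sketch (the channel count and input shape below are illustrative assumptions)::

        m = FrozenBatchNorm2d(16)
        x = torch.rand(1, 16, 32, 32)
        y = m(x)  # same shape as x, normalized with the frozen buffers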
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        n: Optional[int] = None,
    ):
        # n=None for backward-compatibility
        if n is not None:
            warnings.warn("`n` argument is deprecated and has been renamed `num_features`", DeprecationWarning)
            num_features = n
        super().__init__()
        _log_api_usage_once(self)
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
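        # Plain BatchNorm2d checkpoints carry a ``num_batches_tracked`` buffer that this
        # frozen variant does not register, so drop it before delegating to the base loader.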
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x: Tensor) -> Tensor:
        # move reshapes to the beginning
        # to make it fuser-friendly
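        # Algebraically this computes (x - running_mean) / sqrt(running_var + eps) * weight + bias,
        # folded into a single per-channel scale and shift.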
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"


class ConvNormActivation(torch.nn.Sequential):
    """
    Configurable block used for Convolution-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
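
    A minimal usage sketch (the channel sizes and input shape below are illustrative assumptions)::

        block = ConvNormActivation(3, 16, kernel_size=3, stride=2)
        x = torch.rand(1, 3, 224, 224)
        y = block(x)  # Conv2d -> BatchNorm2d -> ReLU; output shape (1, 16, 112, 112)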

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: int = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:
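        # The default padding gives "same"-style output sizes for odd kernels at stride 1,
        # e.g. kernel_size=3, dilation=1 -> padding=1.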
        if padding is None:
            padding = (kernel_size - 1) // 2 * dilation
        if bias is None:
            bias = norm_layer is None
        layers = [
            torch.nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]
        if norm_layer is not None:
            layers.append(norm_layer(out_channels))
        if activation_layer is not None:
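            # Pass ``inplace`` only when it is set; ``inplace=None`` supports activation
            # layers that do not accept that keyword argument.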
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        _log_api_usage_once(self)
        self.out_channels = out_channels


class SqueezeExcitation(torch.nn.Module):
    """
    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
    Parameters ``activation`` and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.

    Args:
        input_channels (int): Number of channels in the input image
        squeeze_channels (int): Number of squeeze channels
        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
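
    A minimal usage sketch (the channel counts and input shape below are illustrative assumptions)::

        se = SqueezeExcitation(input_channels=64, squeeze_channels=16)
        x = torch.rand(1, 64, 32, 32)
        y = se(x)  # per-channel rescaling of x, same shape as the input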
    """

    def __init__(
        self,
        input_channels: int,
        squeeze_channels: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        scale = self.avgpool(input)
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale)

    def forward(self, input: Tensor) -> Tensor:
        scale = self._scale(input)
        return scale * input