""" helper class that supports empty tensors on some nn functions. Ideally, add support directly in PyTorch to empty tensors in those functions. This can be removed once https://github.com/pytorch/pytorch/issues/12013 is implemented """ import warnings from typing import Callable, List, Optional import torch from torch import Tensor class Conv2d(torch.nn.Conv2d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.Conv2d is deprecated and will be " "removed in future versions, use torch.nn.Conv2d instead.", FutureWarning, ) class ConvTranspose2d(torch.nn.ConvTranspose2d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.ConvTranspose2d is deprecated and will be " "removed in future versions, use torch.nn.ConvTranspose2d instead.", FutureWarning, ) class BatchNorm2d(torch.nn.BatchNorm2d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.BatchNorm2d is deprecated and will be " "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning, ) interpolate = torch.nn.functional.interpolate # This is not in nn class FrozenBatchNorm2d(torch.nn.Module): """ BatchNorm2d where the batch statistics and the affine parameters are fixed """ def __init__( self, num_features: int, eps: float = 1e-5, n: Optional[int] = None, ): # n=None for backward-compatibility if n is not None: warnings.warn("`n` argument is deprecated and has been renamed `num_features`", DeprecationWarning) num_features = n super().__init__() self.eps = eps self.register_buffer("weight", torch.ones(num_features)) self.register_buffer("bias", torch.zeros(num_features)) self.register_buffer("running_mean", torch.zeros(num_features)) self.register_buffer("running_var", torch.ones(num_features)) def _load_from_state_dict( self, state_dict: dict, prefix: str, local_metadata: dict, strict: bool, missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str], ): num_batches_tracked_key = prefix + "num_batches_tracked" if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) def forward(self, x: Tensor) -> Tensor: # move reshapes to the beginning # to make it fuser-friendly w = self.weight.reshape(1, -1, 1, 1) b = self.bias.reshape(1, -1, 1, 1) rv = self.running_var.reshape(1, -1, 1, 1) rm = self.running_mean.reshape(1, -1, 1, 1) scale = w * (rv + self.eps).rsqrt() bias = b - rm * scale return x * scale + bias def __repr__(self) -> str: return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})" class ConvNormActivation(torch.nn.Sequential): def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: Optional[int] = None, groups: int = 1, norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: int = 1, inplace: bool = True, ) -> None: if padding is None: padding = (kernel_size - 1) // 2 * dilation layers = [ torch.nn.Conv2d( in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=norm_layer is None, ) ] if norm_layer is not None: layers.append(norm_layer(out_channels)) if activation_layer is not None: layers.append(activation_layer(inplace=inplace)) super().__init__(*layers) self.out_channels = 

class SqueezeExcitation(torch.nn.Module):
    """
    Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507:
    per-channel scaling factors are computed from globally pooled features
    and applied multiplicatively to the input.
    """

    def __init__(
        self,
        input_channels: int,
        squeeze_channels: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        scale = self.avgpool(input)
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale)

    def forward(self, input: Tensor) -> Tensor:
        scale = self._scale(input)
        return scale * input
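

# Illustrative usage sketch (not part of the original module); the channel
# counts and input shapes below are arbitrary assumptions chosen for the example.
if __name__ == "__main__":
    # Conv2d -> BatchNorm2d -> ReLU block. With the default padding and an odd
    # kernel size, only the stride changes the spatial resolution.
    block = ConvNormActivation(in_channels=3, out_channels=16, kernel_size=3, stride=2)
    x = torch.rand(2, 3, 32, 32)
    print(block(x).shape)  # torch.Size([2, 16, 16, 16])

    # Squeeze-and-Excitation: rescales each channel by a factor in (0, 1)
    # computed from its globally pooled response; the shape is unchanged.
    se = SqueezeExcitation(input_channels=16, squeeze_channels=4)
    print(se(block(x)).shape)  # torch.Size([2, 16, 16, 16])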