"""
helper class that supports empty tensors on some nn functions.

Ideally, add support directly in PyTorch to empty tensors in
those functions.

This can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented
"""

import warnings

import torch
from torch import Tensor, Size
from torch.jit.annotations import List, Optional, Tuple


class Conv2d(torch.nn.Conv2d):
    """Deprecated alias of ``torch.nn.Conv2d``.

    Behaves exactly like ``torch.nn.Conv2d`` but emits a ``FutureWarning``
    every time an instance is constructed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stacklevel=2 attributes the warning to the caller's construction
        # site instead of to this module.
        warnings.warn(
            "torchvision.ops.misc.Conv2d is deprecated and will be "
            "removed in future versions, use torch.nn.Conv2d instead.",
            FutureWarning,
            stacklevel=2,
        )


class ConvTranspose2d(torch.nn.ConvTranspose2d):
    """Deprecated alias of ``torch.nn.ConvTranspose2d``.

    Behaves exactly like ``torch.nn.ConvTranspose2d`` but emits a
    ``FutureWarning`` every time an instance is constructed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torchvision.ops.misc.ConvTranspose2d is deprecated and will be "
            "removed in future versions, use torch.nn.ConvTranspose2d instead.",
            FutureWarning,
            stacklevel=2,
        )


class BatchNorm2d(torch.nn.BatchNorm2d):
    """Deprecated alias of ``torch.nn.BatchNorm2d``.

    Behaves exactly like ``torch.nn.BatchNorm2d`` but emits a
    ``FutureWarning`` every time an instance is constructed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torchvision.ops.misc.BatchNorm2d is deprecated and will be "
            "removed in future versions, use torch.nn.BatchNorm2d instead.",
            FutureWarning,
            stacklevel=2,
        )


# Plain alias kept so existing `torchvision.ops.misc.interpolate` imports work.
interpolate = torch.nn.functional.interpolate


# This is not in nn
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters
    are fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are all
    registered as buffers rather than parameters. They are therefore not
    returned by ``parameters()`` (so an optimizer never updates them), but
    they are still saved to and loaded from the state dict.

    Args:
        num_features: number of channels ``C`` of the ``(N, C, H, W)`` input.
            Must be supplied, either directly or via the deprecated ``n``.
        eps: value added to ``running_var`` before the inverse square root.
            With the default of 0, a channel whose running variance is 0
            produces a non-finite scale.
        n: deprecated alias of ``num_features``. When given, it takes
            precedence over ``num_features`` and a ``DeprecationWarning``
            is emitted.

    Raises:
        TypeError: if neither ``num_features`` nor ``n`` is given.
    """

    def __init__(
        self,
        num_features: Optional[int] = None,
        eps: float = 0.,
        n: Optional[int] = None,
    ):
        # n=None for backward-compatibility
        if n is not None:
            warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
                          DeprecationWarning, stacklevel=2)
            num_features = n
        if num_features is None:
            # num_features is logically required. It only defaults to None so
            # that the deprecated `FrozenBatchNorm2d(n=...)` spelling can reach
            # the shim above instead of failing on a missing positional argument.
            raise TypeError("FrozenBatchNorm2d() missing required argument: 'num_features'")
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        """Load buffers, discarding any ``num_batches_tracked`` entry.

        This module registers no ``num_batches_tracked`` buffer. State dicts
        that carry one (presumably saved from a regular BatchNorm layer) would
        otherwise report it as an unexpected key under strict loading, so the
        entry is deleted from ``state_dict`` before delegating to the default
        loader.
        """
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the frozen affine normalisation.

        Computes ``x * scale + bias``, where
        ``scale = weight / sqrt(running_var + eps)`` and
        ``bias = bias - running_mean * scale``. Every statistic is reshaped to
        ``(1, C, 1, 1)``, so it broadcasts along dim 1 of an ``(N, C, H, W)`` input.
        """
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"