# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

import torch
from torch import Tensor
from typing import Callable, List, Optional


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    The normalization is folded into a single per-channel multiply-add:
    ``y = x * scale + bias_term`` with
    ``scale = weight / sqrt(running_var + eps)`` and
    ``bias_term = bias - running_mean * scale``.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        n: Optional[int] = None,
    ):
        """
        Args:
            num_features: number of channels C of the expected (N, C, H, W) input.
            eps: value added to the running variance for numerical stability.
            n: deprecated alias for ``num_features`` (kept for backward compatibility).
        """
        # n=None for backward-compatibility
        if n is not None:
            warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
                          DeprecationWarning)
            num_features = n
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        # Precompute the folded per-channel affine terms up front so `forward`
        # also works when no checkpoint is ever loaded (previously these
        # buffers only existed after `_load_from_state_dict`, so calling
        # `forward` on a freshly constructed module raised AttributeError).
        # They are derived quantities, so keep them out of the state dict
        # (persistent=False) — old checkpoints still load with strict=True.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        bias_term = self.bias - self.running_mean * scale
        self.register_buffer("scale", scale.reshape(1, -1, 1, 1), persistent=False)
        self.register_buffer("bias_term", bias_term.reshape(1, -1, 1, 1), persistent=False)

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        """Load buffers from ``state_dict`` and refresh the folded terms.

        ``num_batches_tracked`` (saved by a regular BatchNorm2d checkpoint)
        and any previously-saved derived buffers (``scale``/``bias_term``)
        are dropped before delegating to the base implementation: they are
        not needed, or are recomputed from the loaded statistics below.
        """
        for derived in ("num_batches_tracked", "scale", "bias_term"):
            key = prefix + derived
            if key in state_dict:
                del state_dict[key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

        # One-time preprocessing: reshape to (1, C, 1, 1) so `forward` can
        # broadcast directly against an (N, C, H, W) input.
        self.weight = self.weight.reshape(1, -1, 1, 1)
        self.bias = self.bias.reshape(1, -1, 1, 1)
        self.running_var = self.running_var.reshape(1, -1, 1, 1)
        self.running_mean = self.running_mean.reshape(1, -1, 1, 1)
        # Refresh the derived buffers from the freshly loaded statistics.
        self.scale = self.weight * (self.running_var + self.eps).rsqrt()
        self.bias_term = self.bias - self.running_mean * self.scale

    def forward(self, x: Tensor) -> Tensor:
        """Apply the frozen, folded batch norm: ``x * scale + bias_term``."""
        return x * self.scale + self.bias_term

    def __repr__(self) -> str:
        # numel() rather than shape[0]: after `_load_from_state_dict` the
        # weight buffer is reshaped to (1, C, 1, 1), so shape[0] would
        # incorrectly report 1 instead of the channel count.
        return f"{self.__class__.__name__}({self.weight.numel()}, eps={self.eps})"