Unverified commit 7a2d0618 authored by F-G Fernandez, committed by GitHub

Added eps attribute to FrozenBatchNorm2d (#2190)

* feat: Added eps argument to FrozenBatchNorm2d

* test: Added unittest for eps addition in FrozenBatchNorm2d

See #2169

* fix: Reverted forward changes for JIT fuser

* fix: Added back n argument for backward-compatibility

* fix: Fixed FrozenBatchNorm2d forward

Added back eps

* feat: Specified deprecation warnings in FrozenBatchNorm2d

* test: Added unittest for deprecation warning in FrozenBatchNorm2d

* style: Fixed lint

* style: Fixed block comment lint
parent a09d129c
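
For orientation, a minimal usage sketch of the signature this commit introduces (a sketch based on the diff below, not part of the commit; the `torchvision.ops.misc` import path is the one the test file references):

import torch
from torchvision.ops.misc import FrozenBatchNorm2d

# New optional `eps` argument; the default of 0. preserves the old numerics.
fbn = FrozenBatchNorm2d(32, eps=1e-5)
out = fbn(torch.rand(4, 32, 28, 28))  # forward stays a single x * scale + bias

# The old `n` argument still works for backward-compatibility, but it now
# emits a DeprecationWarning and is treated as `num_features`.
fbn_legacy = FrozenBatchNorm2d(32, n=32)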
@@ -547,6 +547,35 @@ class FrozenBNTester(unittest.TestCase):
         expected_string = f"FrozenBatchNorm2d({num_features})"
         self.assertEqual(t.__repr__(), expected_string)
 
+    def test_frozenbatchnorm2d_eps(self):
+        sample_size = (4, 32, 28, 28)
+        x = torch.rand(sample_size)
+        state_dict = dict(weight=torch.rand(sample_size[1]),
+                          bias=torch.rand(sample_size[1]),
+                          running_mean=torch.rand(sample_size[1]),
+                          running_var=torch.rand(sample_size[1]),
+                          num_batches_tracked=torch.tensor(100))
+
+        # Check that default eps is zero for backward-compatibility
+        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
+        fbn.load_state_dict(state_dict, strict=False)
+        bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval()
+        bn.load_state_dict(state_dict)
+        # Difference is expected to fall in an acceptable range
+        self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
+
+        # Check computation for eps > 0
+        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
+        fbn.load_state_dict(state_dict, strict=False)
+        bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
+        bn.load_state_dict(state_dict)
+        self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
+
+    def test_frozenbatchnorm2d_n_arg(self):
+        """Ensure a warning is thrown when passing `n` kwarg
+        (remove this when support of `n` is dropped)"""
+        self.assertWarns(DeprecationWarning, ops.misc.FrozenBatchNorm2d, 32, eps=1e-5, n=32)
+
 
 class BoxConversionTester(unittest.TestCase):
     @staticmethod
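
The new test leans on a simple identity: an eval-mode BatchNorm2d computes y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, and FrozenBatchNorm2d folds this into a single affine map y = x * scale + bias. (The test also loads with strict=False because the BatchNorm-style state dict carries num_batches_tracked, which is not among the four buffers FrozenBatchNorm2d registers.) A standalone sketch of that identity, independent of the modules under test:

import torch

c = 32
x = torch.rand(4, c, 28, 28)
w, b = torch.rand(c), torch.rand(c)
rm = torch.rand(c)
rv = torch.rand(c) + 0.5  # variance bounded away from zero for a stable comparison
eps = 1e-5
shape = (1, -1, 1, 1)

# Folded (frozen) form, mirroring FrozenBatchNorm2d.forward
scale = (w * (rv + eps).rsqrt()).reshape(shape)
bias = b.reshape(shape) - rm.reshape(shape) * scale
folded = x * scale + bias

# Textbook eval-mode batch norm
textbook = (x - rm.reshape(shape)) / torch.sqrt(rv.reshape(shape) + eps)
textbook = textbook * w.reshape(shape) + b.reshape(shape)

assert torch.allclose(folded, textbook, atol=1e-5)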
@@ -13,6 +13,7 @@ is implemented
 """
 
 import math
+import warnings
 import torch
 from torchvision.ops import _new_empty_tensor
 from torch.nn import Module, Conv2d
@@ -124,12 +125,18 @@ class FrozenBatchNorm2d(torch.nn.Module):
     are fixed
     """
 
-    def __init__(self, n):
+    def __init__(self, num_features, eps=0., n=None):
+        # n=None for backward-compatibility
+        if n is not None:
+            warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
+                          DeprecationWarning)
+            num_features = n
         super(FrozenBatchNorm2d, self).__init__()
-        self.register_buffer("weight", torch.ones(n))
-        self.register_buffer("bias", torch.zeros(n))
-        self.register_buffer("running_mean", torch.zeros(n))
-        self.register_buffer("running_var", torch.ones(n))
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(num_features))
+        self.register_buffer("bias", torch.zeros(num_features))
+        self.register_buffer("running_mean", torch.zeros(num_features))
+        self.register_buffer("running_var", torch.ones(num_features))
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
@@ -148,7 +155,7 @@ class FrozenBatchNorm2d(torch.nn.Module):
         w = self.weight.reshape(1, -1, 1, 1)
         b = self.bias.reshape(1, -1, 1, 1)
         rv = self.running_var.reshape(1, -1, 1, 1)
         rm = self.running_mean.reshape(1, -1, 1, 1)
-        scale = w * rv.rsqrt()
+        scale = w * (rv + self.eps).rsqrt()
         bias = b - rm * scale
         return x * scale + bias
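
Per the commit bullets, earlier edits to forward were reverted for the JIT fuser: the body stays a flat chain of elementwise ops ending in x * scale + bias, with eps folded in via (rv + self.eps).rsqrt(), so TorchScript can still fuse it. A quick smoke test (this assumes the module remains scriptable at this commit, which the diff itself does not show):

import torch
from torchvision.ops.misc import FrozenBatchNorm2d

scripted = torch.jit.script(FrozenBatchNorm2d(32, eps=1e-5))
out = scripted(torch.rand(4, 32, 28, 28))
assert out.shape == (4, 32, 28, 28)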