Unverified Commit 45e027c7 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[BC-breaking] Change default eps value of FrozenBN (#2933)

* Change default eps value of FrozenBN.

* Update the unit-tests.

* Update the expected values.

* Revert the expected value and use original eps=0 value for flaky tests.

* Post init change of eps.

* Styles.
parent 455cd57c
...@@ -6,9 +6,10 @@ import torch.nn as nn ...@@ -6,9 +6,10 @@ import torch.nn as nn
import numpy as np import numpy as np
from torchvision import models from torchvision import models
import unittest import unittest
import traceback
import random import random
from torchvision.ops.misc import FrozenBatchNorm2d
def set_rng_seed(seed): def set_rng_seed(seed):
torch.manual_seed(seed) torch.manual_seed(seed)
...@@ -149,6 +150,10 @@ class ModelTester(TestCase): ...@@ -149,6 +150,10 @@ class ModelTester(TestCase):
if "retinanet" in name: if "retinanet" in name:
kwargs["score_thresh"] = 0.013 kwargs["score_thresh"] = 0.013
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
if "keypointrcnn" in name or "retinanet" in name:
for module in model.modules():
if isinstance(module, FrozenBatchNorm2d):
module.eps = 0
model.eval().to(device=dev) model.eval().to(device=dev)
input_shape = (3, 300, 300) input_shape = (3, 300, 300)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
......
...@@ -623,10 +623,10 @@ class FrozenBNTester(unittest.TestCase): ...@@ -623,10 +623,10 @@ class FrozenBNTester(unittest.TestCase):
running_var=torch.rand(sample_size[1]), running_var=torch.rand(sample_size[1]),
num_batches_tracked=torch.tensor(100)) num_batches_tracked=torch.tensor(100))
# Check that default eps is zero for backward-compatibility # Check that default eps is equal to the one of BN
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1]) fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
fbn.load_state_dict(state_dict, strict=False) fbn.load_state_dict(state_dict, strict=False)
bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval() bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
bn.load_state_dict(state_dict) bn.load_state_dict(state_dict)
# Difference is expected to fall in an acceptable range # Difference is expected to fall in an acceptable range
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
......
...@@ -51,7 +51,7 @@ class FrozenBatchNorm2d(torch.nn.Module): ...@@ -51,7 +51,7 @@ class FrozenBatchNorm2d(torch.nn.Module):
def __init__( def __init__(
self, self,
num_features: int, num_features: int,
eps: float = 0., eps: float = 1e-5,
n: Optional[int] = None, n: Optional[int] = None,
): ):
# n=None for backward-compatibility # n=None for backward-compatibility
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment