Unverified Commit 45e027c7 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[BC-breaking] Change default eps value of FrozenBN (#2933)

* Change default eps value of FrozenBN.

* Update the unit-tests.`

* Update the expected values.

* Revert the expected value and use original eps=0 value for flaky tests.

* Post init change of eps.

* Styles.
parent 455cd57c
......@@ -6,9 +6,10 @@ import torch.nn as nn
import numpy as np
from torchvision import models
import unittest
import traceback
import random
from torchvision.ops.misc import FrozenBatchNorm2d
def set_rng_seed(seed):
torch.manual_seed(seed)
......@@ -149,6 +150,10 @@ class ModelTester(TestCase):
if "retinanet" in name:
kwargs["score_thresh"] = 0.013
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
if "keypointrcnn" in name or "retinanet" in name:
for module in model.modules():
if isinstance(module, FrozenBatchNorm2d):
module.eps = 0
model.eval().to(device=dev)
input_shape = (3, 300, 300)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
......
......@@ -623,10 +623,10 @@ class FrozenBNTester(unittest.TestCase):
running_var=torch.rand(sample_size[1]),
num_batches_tracked=torch.tensor(100))
# Check that default eps is zero for backward-compatibility
# Check that default eps is equal to the one of BN
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
fbn.load_state_dict(state_dict, strict=False)
bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval()
bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
bn.load_state_dict(state_dict)
# Difference is expected to fall in an acceptable range
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
......
......@@ -51,7 +51,7 @@ class FrozenBatchNorm2d(torch.nn.Module):
def __init__(
self,
num_features: int,
eps: float = 0.,
eps: float = 1e-5,
n: Optional[int] = None,
):
# n=None for backward-compatibility
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment