Unverified Commit 005355bd authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added eps in the __repr__ of FrozenBN (#2852)

* feat: Updated FrozenBN eps to align with BatchNorm

* feat: Added eps to __repr__ of FrozenBN

* test: Updated unittest of __repr__ for FrozenBN

* test: Updated unittest for eps value in BN and FrozenBN

* fix: Revert FrozenBN eps value

* test: Revert test on eps alignment between FrozenBN and BN
parent 6713f034
...@@ -607,10 +607,11 @@ class DeformConvTester(OpTester, unittest.TestCase): ...@@ -607,10 +607,11 @@ class DeformConvTester(OpTester, unittest.TestCase):
class FrozenBNTester(unittest.TestCase): class FrozenBNTester(unittest.TestCase):
def test_frozenbatchnorm2d_repr(self): def test_frozenbatchnorm2d_repr(self):
num_features = 32 num_features = 32
t = ops.misc.FrozenBatchNorm2d(num_features) eps = 1e-5
t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)
# Check integrity of object __repr__ attribute # Check integrity of object __repr__ attribute
expected_string = f"FrozenBatchNorm2d({num_features})" expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
self.assertEqual(t.__repr__(), expected_string) self.assertEqual(t.__repr__(), expected_string)
def test_frozenbatchnorm2d_eps(self): def test_frozenbatchnorm2d_eps(self):
......
...@@ -96,4 +96,4 @@ class FrozenBatchNorm2d(torch.nn.Module): ...@@ -96,4 +96,4 @@ class FrozenBatchNorm2d(torch.nn.Module):
return x * scale + bias return x * scale + bias
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]})" return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment