You need to sign in or sign up before continuing.
Unverified Commit 5db8998a authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added number of features in FrozenBatchNorm2d __repr__ (#2168)

* feat: Added number of features in FrozenBatchNorm2d repr

While BatchNorm layers have extensive information in their repr, FrozenBatchNorm2d has one

* refactor: Refactored FrozenBatchNorm2d __repr__

* test: Added unittest for FrozenBatchNorm2d __repr__

* style: Removed blank lines in test_ops

* refactor: Avoids creating an extra attribute for __repr__

* style: Switched __repr__ to f-string

Since support of Python version ealier than 3.6 have been dropped, f-string can be used.

* fix: Fixed typo in __repr__

* style: Switched unittest .format to f-string
parent c7147af0
...@@ -538,5 +538,15 @@ class DeformConvTester(OpTester, unittest.TestCase): ...@@ -538,5 +538,15 @@ class DeformConvTester(OpTester, unittest.TestCase):
(x, offset, weight, bias), nondet_tol=1e-5) (x, offset, weight, bias), nondet_tol=1e-5)
class FrozenBNTester(unittest.TestCase):
def test_frozenbatchnorm2d_repr(self):
num_features = 32
t = ops.misc.FrozenBatchNorm2d(num_features)
# Check integrity of object __repr__ attribute
expected_string = f"FrozenBatchNorm2d({num_features})"
self.assertEqual(t.__repr__(), expected_string)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -151,3 +151,6 @@ class FrozenBatchNorm2d(torch.nn.Module): ...@@ -151,3 +151,6 @@ class FrozenBatchNorm2d(torch.nn.Module):
scale = w * rv.rsqrt() scale = w * rv.rsqrt()
bias = b - rm * scale bias = b - rm * scale
return x * scale + bias return x * scale + bias
def __repr__(self):
return f"{self.__class__.__name__}({self.weight.shape[0]})"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment