Unverified Commit a9ef0f99 authored by Jeff Daily, committed by GitHub

Benchmarks - Keep BatchNorm as fp32 for pytorch cnn models cast to fp16 (#322)

**Description**
The BatchNorm operator is not numerically stable in fp16. The PyTorch documentation recommends keeping BatchNorm in fp32 for fp16 AMP models; see https://pytorch.org/docs/stable/amp.html#ops-that-can-autocast-to-float32. Keeping BatchNorm in fp32 makes the SuperBench CNN benchmarks more accurately reflect real workloads.
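As a minimal, self-contained sketch of the idea (not the SuperBench code itself; the helper name `keep_batchnorm_fp32` and the use of torchvision's `resnet50` are illustrative), casting the whole model to fp16 and then converting every BatchNorm layer back to fp32 looks roughly like this:

```python
import torch
import torchvision.models as models


def keep_batchnorm_fp32(module):
    # _BatchNorm is the shared base class of BatchNorm1d/2d/3d and SyncBatchNorm,
    # so matching on it catches every BatchNorm variant in the model.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        keep_batchnorm_fp32(child)
    return module


# Illustrative model choice; any torchvision CNN would behave the same way.
model = models.resnet50().to(dtype=torch.float16)
model = keep_batchnorm_fp32(model)

# BatchNorm parameters stay in fp32 while the rest of the model is fp16.
print(model.bn1.weight.dtype)    # torch.float32
print(model.conv1.weight.dtype)  # torch.float16
```

This mirrors the patch below: the recursion converts only the BatchNorm layers back to fp32 after the blanket fp16 cast, rather than relying on autocast to handle the dtype per-op.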
parent 425b9ff8
@@ -15,6 +15,14 @@
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


def _keep_BatchNorm_as_float(module):
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        _keep_BatchNorm_as_float(child)
    return module


class PytorchCNN(PytorchBase):
    """The CNN benchmark class."""
    def __init__(self, name, parameters=''):
@@ -63,6 +71,7 @@ def _create_model(self, precision):
        try:
            self._model = getattr(models, self._args.model_type)()
            self._model = self._model.to(dtype=getattr(torch, precision.value))
            self._model = _keep_BatchNorm_as_float(self._model)
            if self._gpu_available:
                self._model = self._model.cuda()
        except BaseException as e: