Commit 713e0fb8 authored by Jerry Ma

Better FP16 support in the PyTorch fp16 utils.

This commit adds an FP16Model class as a successor to network_to_half.

The benefits of this class are:

- Preservation of single precision for BatchNorm layers. network_to_half()
  converts BatchNorm moment tensors to half precision and then back to
  single precision, which hurts the accuracy of the moment estimators and
  occasionally results in NaNs.
- Support for nn.Modules whose forward() takes multiple arguments:
  FP16Model.forward accepts *inputs and casts each positional input to half
  precision before calling the wrapped network (see the usage sketch below).
parent cc85a2e5
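For illustration, a minimal usage sketch of the new wrapper (the TwoInputNet module and the tensor names below are hypothetical and not part of this commit; it assumes apex is installed with these changes and a CUDA device is available):

import torch
import torch.nn as nn

from apex.fp16_utils import FP16Model

class TwoInputNet(nn.Module):
    # Hypothetical module with a two-argument forward(), which the old
    # nn.Sequential(tofp16(), ...) wrapper built by network_to_half cannot drive.
    def __init__(self):
        super(TwoInputNet, self).__init__()
        self.bn = nn.BatchNorm2d(3, affine=True)   # left in FP32 by FP16Model
        self.conv = nn.Conv2d(3, 8, 3)             # converted to FP16

    def forward(self, x, y):
        return self.conv(self.bn(x + y))

model = FP16Model(TwoInputNet().cuda())
x = torch.randn(4, 3, 32, 32).cuda()
y = torch.randn(4, 3, 32, 32).cuda()
out = model(x, y)                # both positional inputs are cast to half
assert out.dtype == torch.half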
@@ -7,6 +7,9 @@ from .fp16util import (
     tofp16,
     to_python_float,
     clip_grad_norm,
+    convert_module,
+    convert_network,
+    FP16Model,
 )
...
@@ -3,9 +3,10 @@ import torch.nn as nn
 from torch.autograd import Variable
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


 class tofp16(nn.Module):
     """
-    Model wrapper that implements::
+    Utility module that implements::

         def forward(self, input):
             return input.half()
@@ -19,14 +20,11 @@ class tofp16(nn.Module):

 def BN_convert_float(module):
-    '''
-    Designed to work with network_to_half.
-    BatchNorm layers need parameters in single precision.
-    Find all layers and convert them back to float. This can't
-    be done with built in .apply as that function will apply
-    fn to all modules, parameters, and buffers. Thus we wouldn't
-    be able to guard the float conversion based on the module type.
-    '''
+    """
+    Utility function for network_to_half().
+
+    Retained for legacy purposes.
+    """
     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
         module.float()
     for child in module.children():
@@ -37,10 +35,53 @@ def BN_convert_float(module):

 def network_to_half(network):
     """
     Convert model to half precision in a batchnorm-safe way.
+
+    Retained for legacy purposes. It is recommended to use FP16Model.
     """
     return nn.Sequential(tofp16(), BN_convert_float(network.half()))


+def convert_module(module, dtype):
+    """
+    Converts a module's immediate parameters and buffers to dtype.
+    """
+    for param in module.parameters(recurse=False):
+        if param is not None:
+            if param.data.dtype.is_floating_point:
+                param.data = param.data.to(dtype=dtype)
+            if param._grad is not None and param._grad.data.dtype.is_floating_point:
+                param._grad.data = param._grad.data.to(dtype=dtype)
+
+    for buf in module.buffers(recurse=False):
+        if buf is not None and buf.data.dtype.is_floating_point:
+            buf.data = buf.data.to(dtype=dtype)
+
+
+def convert_network(network, dtype):
+    """
+    Converts a network's parameters and buffers to dtype.
+    """
+    for module in network.modules():
+        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
+            continue
+        convert_module(module, dtype)
+    return network
+
+
+class FP16Model(nn.Module):
+    """
+    Convert model to half precision in a batchnorm-safe way.
+    """
+
+    def __init__(self, network):
+        super(FP16Model, self).__init__()
+        self.network = convert_network(network, dtype=torch.half)
+
+    def forward(self, *inputs):
+        inputs = tuple(t.half() for t in inputs)
+        return self.network(*inputs)
+
 def backwards_debug_hook(grad):
     raise RuntimeError("master_params received a gradient in the backward pass!")
...
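To make the BatchNorm exemption concrete, here is a small sketch (the toy layers below are chosen for illustration and are not part of the commit; only dtypes are inspected, so no GPU is needed):

import torch
import torch.nn as nn

from apex.fp16_utils import convert_network

# Toy network: a Conv layer next to an affine BatchNorm layer.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8, affine=True))
net = convert_network(net, dtype=torch.half)

print(net[0].weight.dtype)        # torch.float16 -- Conv parameters are cast
print(net[1].weight.dtype)        # torch.float32 -- affine BatchNorm params are skipped
print(net[1].running_mean.dtype)  # torch.float32 -- BatchNorm buffers also stay in FP32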
+import unittest
+
+import torch
+import torch.nn as nn
+
+from apex.fp16_utils import FP16Model
+
+
+class DummyBlock(nn.Module):
+    def __init__(self):
+        super(DummyBlock, self).__init__()
+        self.conv = nn.Conv2d(10, 10, 2)
+        self.bn = nn.BatchNorm2d(10, affine=True)
+
+    def forward(self, x):
+        return self.conv(self.bn(x))
+
+
+class DummyNet(nn.Module):
+    def __init__(self):
+        super(DummyNet, self).__init__()
+        self.conv1 = nn.Conv2d(3, 10, 2)
+        self.bn1 = nn.BatchNorm2d(10, affine=False)
+        self.db1 = DummyBlock()
+        self.db2 = DummyBlock()
+
+    def forward(self, x):
+        out = x
+        out = self.conv1(out)
+        out = self.bn1(out)
+        out = self.db1(out)
+        out = self.db2(out)
+        return out
+
+
+class DummyNetWrapper(nn.Module):
+    def __init__(self):
+        super(DummyNetWrapper, self).__init__()
+        self.bn = nn.BatchNorm2d(3, affine=True)
+        self.dn = DummyNet()
+
+    def forward(self, x):
+        return self.dn(self.bn(x))
+
+
+class TestFP16Model(unittest.TestCase):
+    def setUp(self):
+        self.N = 64
+        self.C_in = 3
+        self.H_in = 16
+        self.W_in = 32
+        self.in_tensor = torch.randn((self.N, self.C_in, self.H_in, self.W_in)).cuda()
+        self.orig_model = DummyNetWrapper().cuda()
+        self.fp16_model = FP16Model(self.orig_model)
+
+    def test_params_and_buffers(self):
+        exempted_modules = [
+            self.fp16_model.network.bn,
+            self.fp16_model.network.dn.db1.bn,
+            self.fp16_model.network.dn.db2.bn,
+        ]
+        for m in self.fp16_model.modules():
+            expected_dtype = torch.float if (m in exempted_modules) else torch.half
+            for p in m.parameters(recurse=False):
+                assert p.dtype == expected_dtype
+            for b in m.buffers(recurse=False):
+                assert b.dtype in (expected_dtype, torch.int64)
+
+    def test_output_is_half(self):
+        out_tensor = self.fp16_model(self.in_tensor)
+        assert out_tensor.dtype == torch.half
 import unittest
 import sys

-test_dirs = ["run_amp", "run_mixed_adam"]
+test_dirs = ["run_amp", "run_fp16util", "run_mixed_adam"]

 runner = unittest.TextTestRunner(verbosity=2)
...
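For reference, a hedged sketch of how the new run_fp16util tests could be loaded and run with the same TextTestRunner; the discover() call mirrors standard unittest usage and is an assumption about how run_test.py consumes test_dirs, whose full body is not shown here:

import unittest

# Assumed invocation: discover tests in the new run_fp16util directory and run
# them with the verbosity-2 runner shown above. TestFP16Model's CUDA-backed
# setUp requires a GPU.
suite = unittest.TestLoader().discover("run_fp16util")
unittest.TextTestRunner(verbosity=2).run(suite)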