Commit 713e0fb8 authored by Jerry Ma's avatar Jerry Ma
Browse files

Better FP16 support in pytorch fp16 utils.

This commit adds an FP16Model class as a successor to network_to_half.

The benefits of this class are:

- Preservation of single-precision for BatchNorm layers. The models
  generated by network_to_half() convert BatchNorm moment tensors to
  half-precision, then back to single-precision, which hurts the
  accuracy of the moment estimators and occasionally results in NaNs.
- Support for multi-argument nn.Modules (self-explanatory from code).
parent cc85a2e5
......@@ -3,10 +3,13 @@ from .fp16util import (
network_to_half,
prep_param_lists,
model_grads_to_master_grads,
master_params_to_model_params,
master_params_to_model_params,
tofp16,
to_python_float,
clip_grad_norm,
convert_module,
convert_network,
FP16Model,
)
......
......@@ -3,9 +3,10 @@ import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class tofp16(nn.Module):
"""
Model wrapper that implements::
Utility module that implements::
def forward(self, input):
return input.half()
......@@ -19,14 +20,11 @@ class tofp16(nn.Module):
def BN_convert_float(module):
'''
Designed to work with network_to_half.
BatchNorm layers need parameters in single precision.
Find all layers and convert them back to float. This can't
be done with built in .apply as that function will apply
fn to all modules, parameters, and buffers. Thus we wouldn't
be able to guard the float conversion based on the module type.
'''
"""
Utility function for network_to_half().
Retained for legacy purposes.
"""
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
module.float()
for child in module.children():
......@@ -37,16 +35,59 @@ def BN_convert_float(module):
def network_to_half(network):
"""
Convert model to half precision in a batchnorm-safe way.
Retained for legacy purposes. It is recommended to use FP16Model.
"""
return nn.Sequential(tofp16(), BN_convert_float(network.half()))
def convert_module(module, dtype):
"""
Converts a module's immediate parameters and buffers to dtype.
"""
for param in module.parameters(recurse=False):
if param is not None:
if param.data.dtype.is_floating_point:
param.data = param.data.to(dtype=dtype)
if param._grad is not None and param._grad.data.dtype.is_floating_point:
param._grad.data = param._grad.data.to(dtype=dtype)
for buf in module.buffers(recurse=False):
if buf is not None and buf.data.dtype.is_floating_point:
buf.data = buf.data.to(dtype=dtype)
def convert_network(network, dtype):
"""
Converts a network's parameters and buffers to dtype.
"""
for module in network.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
continue
convert_module(module, dtype)
return network
class FP16Model(nn.Module):
"""
Convert model to half precision in a batchnorm-safe way.
"""
def __init__(self, network):
super(FP16Model, self).__init__()
self.network = convert_network(network, dtype=torch.half)
def forward(self, *inputs):
inputs = tuple(t.half() for t in inputs)
return self.network(*inputs)
def backwards_debug_hook(grad):
raise RuntimeError("master_params recieved a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
"""
Creates a list of FP32 master parameters for a given model, as in
Creates a list of FP32 master parameters for a given model, as in
`Training Neural Networks with Mixed Precision: Real Examples`_.
Args:
......
import unittest
import torch
import torch.nn as nn
from apex.fp16_utils import FP16Model
class DummyBlock(nn.Module):
def __init__(self):
super(DummyBlock, self).__init__()
self.conv = nn.Conv2d(10, 10, 2)
self.bn = nn.BatchNorm2d(10, affine=True)
def forward(self, x):
return self.conv(self.bn(x))
class DummyNet(nn.Module):
def __init__(self):
super(DummyNet, self).__init__()
self.conv1 = nn.Conv2d(3, 10, 2)
self.bn1 = nn.BatchNorm2d(10, affine=False)
self.db1 = DummyBlock()
self.db2 = DummyBlock()
def forward(self, x):
out = x
out = self.conv1(out)
out = self.bn1(out)
out = self.db1(out)
out = self.db2(out)
return out
class DummyNetWrapper(nn.Module):
def __init__(self):
super(DummyNetWrapper, self).__init__()
self.bn = nn.BatchNorm2d(3, affine=True)
self.dn = DummyNet()
def forward(self, x):
return self.dn(self.bn(x))
class TestFP16Model(unittest.TestCase):
def setUp(self):
self.N = 64
self.C_in = 3
self.H_in = 16
self.W_in = 32
self.in_tensor = torch.randn((self.N, self.C_in, self.H_in, self.W_in)).cuda()
self.orig_model = DummyNetWrapper().cuda()
self.fp16_model = FP16Model(self.orig_model)
def test_params_and_buffers(self):
exempted_modules = [
self.fp16_model.network.bn,
self.fp16_model.network.dn.db1.bn,
self.fp16_model.network.dn.db2.bn,
]
for m in self.fp16_model.modules():
expected_dtype = torch.float if (m in exempted_modules) else torch.half
for p in m.parameters(recurse=False):
assert p.dtype == expected_dtype
for b in m.buffers(recurse=False):
assert b.dtype in (expected_dtype, torch.int64)
def test_output_is_half(self):
out_tensor = self.fp16_model(self.in_tensor)
assert out_tensor.dtype == torch.half
import unittest
import sys
test_dirs = ["run_amp", "run_mixed_adam"]
test_dirs = ["run_amp", "run_fp16util", "run_mixed_adam"]
runner = unittest.TextTestRunner(verbosity=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment