Commit 8db3f95c authored by Michael Carilli

Merge branch 'master' into api_refactor

parents 1f693b92 1b903852
......@@ -3,10 +3,13 @@ from .fp16util import (
network_to_half,
prep_param_lists,
model_grads_to_master_grads,
    master_params_to_model_params,
tofp16,
to_python_float,
clip_grad_norm,
convert_module,
convert_network,
FP16Model,
)
from .fp16_optimizer import FP16_Optimizer
......
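For reference, a minimal sketch of importing the expanded API surface (assumes apex is installed; all names are exported by the __init__ above):

from apex.fp16_utils import (
    network_to_half,    # legacy batchnorm-safe half conversion
    FP16Model,          # new module wrapper added in this merge
    convert_network,    # converts params/buffers, skipping affine batchnorm
    prep_param_lists,   # FP32 master parameter helper
    FP16_Optimizer,
)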
......@@ -3,9 +3,10 @@ import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class tofp16(nn.Module):
"""
Model wrapper that implements::
Utility module that implements::
def forward(self, input):
return input.half()
......@@ -19,14 +20,11 @@ class tofp16(nn.Module):
def BN_convert_float(module):
'''
Designed to work with network_to_half.
BatchNorm layers need parameters in single precision.
Find all layers and convert them back to float. This can't
be done with built in .apply as that function will apply
fn to all modules, parameters, and buffers. Thus we wouldn't
be able to guard the float conversion based on the module type.
'''
"""
Utility function for network_to_half().
Retained for legacy purposes.
"""
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
module.float()
for child in module.children():
......@@ -37,16 +35,59 @@ def BN_convert_float(module):
def network_to_half(network):
"""
Convert model to half precision in a batchnorm-safe way.
Retained for legacy purposes. It is recommended to use FP16Model.
"""
return nn.Sequential(tofp16(), BN_convert_float(network.half()))
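For context, a minimal usage sketch of this legacy path (toy model and shapes are hypothetical; requires a CUDA device so that the FP32 batchnorm can consume FP16 activations):

import torch
import torch.nn as nn
from apex.fp16_utils import network_to_half  # assumes apex is installed

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
half_model = network_to_half(toy).cuda()             # Sequential(tofp16(), BN-safe half model)
out = half_model(torch.randn(4, 3, 16, 16).cuda())   # tofp16() casts the input to half
print(out.dtype)                                     # torch.float16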
def convert_module(module, dtype):
"""
Converts a module's immediate parameters and buffers to dtype.
"""
for param in module.parameters(recurse=False):
if param is not None:
if param.data.dtype.is_floating_point:
param.data = param.data.to(dtype=dtype)
if param._grad is not None and param._grad.data.dtype.is_floating_point:
param._grad.data = param._grad.data.to(dtype=dtype)
for buf in module.buffers(recurse=False):
if buf is not None and buf.data.dtype.is_floating_point:
buf.data = buf.data.to(dtype=dtype)
def convert_network(network, dtype):
"""
Converts a network's parameters and buffers to dtype.
"""
for module in network.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
continue
convert_module(module, dtype)
return network
class FP16Model(nn.Module):
"""
Convert model to half precision in a batchnorm-safe way.
"""
def __init__(self, network):
super(FP16Model, self).__init__()
self.network = convert_network(network, dtype=torch.half)
def forward(self, *inputs):
inputs = tuple(t.half() for t in inputs)
return self.network(*inputs)
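A parallel sketch for the new FP16Model wrapper (toy model hypothetical): convert_network leaves affine batchnorm parameters in FP32, while other parameters, floating-point buffers, and the inputs are cast to half.

import torch
import torch.nn as nn
from apex.fp16_utils import FP16Model  # assumes apex is installed

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).cuda()
fp16_model = FP16Model(toy)

print(fp16_model.network[0].weight.dtype)            # torch.float16 (conv converted)
print(fp16_model.network[1].weight.dtype)            # torch.float32 (affine batchnorm kept)
out = fp16_model(torch.randn(4, 3, 16, 16).cuda())   # inputs cast to half in forward()
print(out.dtype)                                     # torch.float16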
def backwards_debug_hook(grad):
    raise RuntimeError("master_params received a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
"""
    Creates a list of FP32 master parameters for a given model, as in
`Training Neural Networks with Mixed Precision: Real Examples`_.
Args:
......
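The truncated docstring above belongs to the master-parameter utilities; a minimal sketch of the training loop they support (toy model, dummy loss, and loss scale are illustrative, and apex is assumed installed):

import torch
import torch.nn as nn
from apex.fp16_utils import (network_to_half, prep_param_lists,
                             model_grads_to_master_grads,
                             master_params_to_model_params)

model = network_to_half(nn.Linear(16, 4)).cuda()        # FP16 toy model
model_params, master_params = prep_param_lists(model)   # FP16 params + FP32 master copies
optimizer = torch.optim.SGD(master_params, lr=0.01)
loss_scale = 128.0

loss = model(torch.randn(8, 16).cuda()).float().sum()   # dummy loss, reduced in FP32
(loss * loss_scale).backward()                           # scaled backward fills FP16 grads

model_grads_to_master_grads(model_params, master_params)     # copy grads up to FP32
for p in master_params:
    p.grad.data.mul_(1.0 / loss_scale)                        # unscale in FP32
optimizer.step()                                              # update FP32 master weights
master_params_to_model_params(model_params, master_params)   # copy back into the FP16 model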
import unittest
import torch
import torch.nn as nn
from apex.fp16_utils import FP16Model
class DummyBlock(nn.Module):
def __init__(self):
super(DummyBlock, self).__init__()
self.conv = nn.Conv2d(10, 10, 2)
self.bn = nn.BatchNorm2d(10, affine=True)
def forward(self, x):
return self.conv(self.bn(x))
class DummyNet(nn.Module):
def __init__(self):
super(DummyNet, self).__init__()
self.conv1 = nn.Conv2d(3, 10, 2)
self.bn1 = nn.BatchNorm2d(10, affine=False)
self.db1 = DummyBlock()
self.db2 = DummyBlock()
def forward(self, x):
out = x
out = self.conv1(out)
out = self.bn1(out)
out = self.db1(out)
out = self.db2(out)
return out
class DummyNetWrapper(nn.Module):
def __init__(self):
super(DummyNetWrapper, self).__init__()
self.bn = nn.BatchNorm2d(3, affine=True)
self.dn = DummyNet()
def forward(self, x):
return self.dn(self.bn(x))
class TestFP16Model(unittest.TestCase):
def setUp(self):
self.N = 64
self.C_in = 3
self.H_in = 16
self.W_in = 32
self.in_tensor = torch.randn((self.N, self.C_in, self.H_in, self.W_in)).cuda()
self.orig_model = DummyNetWrapper().cuda()
self.fp16_model = FP16Model(self.orig_model)
def test_params_and_buffers(self):
exempted_modules = [
self.fp16_model.network.bn,
self.fp16_model.network.dn.db1.bn,
self.fp16_model.network.dn.db2.bn,
]
for m in self.fp16_model.modules():
expected_dtype = torch.float if (m in exempted_modules) else torch.half
for p in m.parameters(recurse=False):
assert p.dtype == expected_dtype
for b in m.buffers(recurse=False):
assert b.dtype in (expected_dtype, torch.int64)
def test_output_is_half(self):
out_tensor = self.fp16_model(self.in_tensor)
assert out_tensor.dtype == torch.half
import unittest
import sys
test_dirs = ["run_amp", "run_mixed_adam"]
test_dirs = ["run_amp", "run_fp16util", "run_mixed_adam"]
runner = unittest.TextTestRunner(verbosity=2)
......
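The runner body is truncated above; a minimal sketch of how such a script might iterate test_dirs with standard unittest discovery (the structure is an assumption, not the exact file):

import unittest
import sys

test_dirs = ["run_amp", "run_fp16util", "run_mixed_adam"]
runner = unittest.TextTestRunner(verbosity=2)

errcode = 0
for test_dir in test_dirs:
    suite = unittest.TestLoader().discover(test_dir)   # collect test_*.py under the dir
    print("\nExecuting tests from " + test_dir)
    result = runner.run(suite)
    if not result.wasSuccessful():
        errcode = 1
sys.exit(errcode)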