Commit 8db3f95c authored by Michael Carilli

Merge branch 'master' into api_refactor

parents 1f693b92 1b903852
...
@@ -7,6 +7,9 @@ from .fp16util import (
     tofp16,
     to_python_float,
     clip_grad_norm,
+    convert_module,
+    convert_network,
+    FP16Model,
 )
 from .fp16_optimizer import FP16_Optimizer
...
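For reference, the three utilities added to this export list become importable straight from the package; a minimal sketch, assuming the package is installed as apex.fp16_utils:

    from apex.fp16_utils import convert_module, convert_network, FP16Model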
...
@@ -3,9 +3,10 @@ import torch.nn as nn
 from torch.autograd import Variable
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 
 class tofp16(nn.Module):
     """
-    Model wrapper that implements::
+    Utility module that implements::
 
         def forward(self, input):
             return input.half()
...
@@ -19,14 +20,11 @@ class tofp16(nn.Module):
 
 def BN_convert_float(module):
-    '''
-    Designed to work with network_to_half.
-    BatchNorm layers need parameters in single precision.
-    Find all layers and convert them back to float. This can't
-    be done with built in .apply as that function will apply
-    fn to all modules, parameters, and buffers. Thus we wouldn't
-    be able to guard the float conversion based on the module type.
-    '''
+    """
+    Utility function for network_to_half().
+
+    Retained for legacy purposes.
+    """
     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
         module.float()
     for child in module.children():
...
@@ -37,10 +35,53 @@ def BN_convert_float(module):
 
 def network_to_half(network):
     """
-    Convert model to half precision in a batchnorm-safe way.
+    Retained for legacy purposes. It is recommended to use FP16Model.
     """
     return nn.Sequential(tofp16(), BN_convert_float(network.half()))
+
+
+def convert_module(module, dtype):
+    """
+    Converts a module's immediate parameters and buffers to dtype.
+    """
+    for param in module.parameters(recurse=False):
+        if param is not None:
+            if param.data.dtype.is_floating_point:
+                param.data = param.data.to(dtype=dtype)
+            if param._grad is not None and param._grad.data.dtype.is_floating_point:
+                param._grad.data = param._grad.data.to(dtype=dtype)
+
+    for buf in module.buffers(recurse=False):
+        if buf is not None and buf.data.dtype.is_floating_point:
+            buf.data = buf.data.to(dtype=dtype)
+
+
+def convert_network(network, dtype):
+    """
+    Converts a network's parameters and buffers to dtype.
+    """
+    for module in network.modules():
+        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
+            continue
+        convert_module(module, dtype)
+    return network
+
+
+class FP16Model(nn.Module):
+    """
+    Convert model to half precision in a batchnorm-safe way.
+    """
+
+    def __init__(self, network):
+        super(FP16Model, self).__init__()
+        self.network = convert_network(network, dtype=torch.half)
+
+    def forward(self, *inputs):
+        inputs = tuple(t.half() for t in inputs)
+        return self.network(*inputs)
 
 
 def backwards_debug_hook(grad):
     raise RuntimeError("master_params received a gradient in the backward pass!")
...
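To show how the pieces added above fit together, here is a minimal usage sketch of FP16Model; the toy network and tensor shapes are invented for illustration, and a CUDA device is assumed since cuDNN batch norm accepts half inputs with float parameters:

    import torch
    import torch.nn as nn
    from apex.fp16_utils import FP16Model

    # Hypothetical toy network: conv + affine batch norm + activation.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).cuda()
    fp16_model = FP16Model(model)

    x = torch.randn(4, 3, 16, 16).cuda()        # FP32 input; forward() halves it
    y = fp16_model(x)
    print(y.dtype)                               # torch.float16
    print(fp16_model.network[1].weight.dtype)    # torch.float32: affine BN is exempt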
import unittest

import torch
import torch.nn as nn

from apex.fp16_utils import FP16Model


class DummyBlock(nn.Module):
    def __init__(self):
        super(DummyBlock, self).__init__()
        self.conv = nn.Conv2d(10, 10, 2)
        self.bn = nn.BatchNorm2d(10, affine=True)

    def forward(self, x):
        return self.conv(self.bn(x))


class DummyNet(nn.Module):
    def __init__(self):
        super(DummyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, 2)
        self.bn1 = nn.BatchNorm2d(10, affine=False)
        self.db1 = DummyBlock()
        self.db2 = DummyBlock()

    def forward(self, x):
        out = x
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.db1(out)
        out = self.db2(out)
        return out


class DummyNetWrapper(nn.Module):
    def __init__(self):
        super(DummyNetWrapper, self).__init__()
        self.bn = nn.BatchNorm2d(3, affine=True)
        self.dn = DummyNet()

    def forward(self, x):
        return self.dn(self.bn(x))


class TestFP16Model(unittest.TestCase):
    def setUp(self):
        self.N = 64
        self.C_in = 3
        self.H_in = 16
        self.W_in = 32
        self.in_tensor = torch.randn((self.N, self.C_in, self.H_in, self.W_in)).cuda()
        self.orig_model = DummyNetWrapper().cuda()
        self.fp16_model = FP16Model(self.orig_model)

    def test_params_and_buffers(self):
        exempted_modules = [
            self.fp16_model.network.bn,
            self.fp16_model.network.dn.db1.bn,
            self.fp16_model.network.dn.db2.bn,
        ]
        for m in self.fp16_model.modules():
            expected_dtype = torch.float if (m in exempted_modules) else torch.half
            for p in m.parameters(recurse=False):
                assert p.dtype == expected_dtype
            for b in m.buffers(recurse=False):
                assert b.dtype in (expected_dtype, torch.int64)

    def test_output_is_half(self):
        out_tensor = self.fp16_model(self.in_tensor)
        assert out_tensor.dtype == torch.half
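An aside on the buffer check above: it admits torch.int64 because, as of PyTorch 0.4, batch norm layers carry an integer num_batches_tracked buffer, and convert_module only converts floating-point tensors. A quick way to confirm the buffer dtypes:

    import torch.nn as nn

    bn = nn.BatchNorm2d(10)
    for name, buf in bn.named_buffers():
        print(name, buf.dtype)
    # running_mean torch.float32
    # running_var torch.float32
    # num_batches_tracked torch.int64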
 import unittest
 import sys
-test_dirs = ["run_amp", "run_mixed_adam"]
+test_dirs = ["run_amp", "run_fp16util", "run_mixed_adam"]
 runner = unittest.TextTestRunner(verbosity=2)
...