"...text-generation-inference.git" did not exist on "069895b9859b776cab8145f3fa6f6d16ac40af47"
Commit 5c6144e6 authored by Michael Carilli

FP16_Optimizer now preserves param order and casts per-param state tensors to FP32

parent 4a8cf7ad
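
For context, a minimal sketch of the usage pattern this wrapper targets (the model, shapes, and training step below are illustrative, not part of this commit; it assumes the apex-style FP16_Optimizer API with backward()/step()):

    import torch
    from apex.fp16_utils import FP16_Optimizer

    # Any FP16 CUDA module would do; Linear is just a placeholder.
    model = torch.nn.Linear(512, 512).cuda().half()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Wrapping builds FP32 master copies of the FP16 params. If the inner
    # optimizer already has per-param state (e.g., restored from a checkpoint),
    # this commit re-keys that state to the masters and casts it to FP32.
    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    out = model(torch.randn(8, 512, device='cuda', dtype=torch.half))
    loss = out.float().sum()
    optimizer.backward(loss)  # the wrapper applies loss scaling internally
    optimizer.step()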
@@ -13,7 +13,7 @@ from .fp16util import (
 from .fused_weight_norm import Fused_Weight_Norm
-from .fp16_optimizer import fp32_to_fp16, fp16_to_fp32, FP16_Module, FP16_Optimizer
+from .fp16_optimizer import FP16_Optimizer
 from .loss_scaler import LossScaler, DynamicLossScaler

@@ -7,48 +7,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from .loss_scaler import DynamicLossScaler, LossScaler
 from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
-FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
-HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
-
-def conversion_helper(val, conversion):
-    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
-    if not isinstance(val, (tuple, list)):
-        return conversion(val)
-    rtn = [conversion_helper(v, conversion) for v in val]
-    if isinstance(val, tuple):
-        rtn = tuple(rtn)
-    return rtn
-
-def fp32_to_fp16(val):
-    """Convert fp32 `val` to fp16"""
-    def half_conversion(val):
-        val_typecheck = val
-        if isinstance(val_typecheck, (Parameter, Variable)):
-            val_typecheck = val.data
-        if isinstance(val_typecheck, FLOAT_TYPES):
-            val = val.half()
-        return val
-    return conversion_helper(val, half_conversion)
-
-def fp16_to_fp32(val):
-    """Convert fp16 `val` to fp32"""
-    def float_conversion(val):
-        val_typecheck = val
-        if isinstance(val_typecheck, (Parameter, Variable)):
-            val_typecheck = val.data
-        if isinstance(val_typecheck, HALF_TYPES):
-            val = val.float()
-        return val
-    return conversion_helper(val, float_conversion)
-
-class FP16_Module(nn.Module):
-    def __init__(self, module):
-        super(FP16_Module, self).__init__()
-        self.add_module('module', module.half())
-
-    def forward(self, *inputs, **kwargs):
-        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
-
 # TODO: Update overflow check + downscale to use Carl's fused kernel.
 class FP16_Optimizer(object):
     """
@@ -151,40 +109,54 @@ class FP16_Optimizer(object):
         if not torch.cuda.is_available:
             raise SystemError("Cannot use fp16 without CUDA.")
+        self.optimizer = init_optimizer
+        # init_state_dict sets up an alternative way to cast per-param state tensors.
+        # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
+        # init_state_dict = init_optimizer.state_dict()
+
         self.fp16_groups = []
         self.fp32_from_fp16_groups = []
         self.fp32_from_fp32_groups = []
-        for i, param_group in enumerate(init_optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
             print("FP16_Optimizer processing param group {}:".format(i))
             fp16_params_this_group = []
             fp32_params_this_group = []
+            master_params_this_group = []
+            fp32_from_fp16_params_this_group = []
             for param in param_group['params']:
                 if param.requires_grad:
                     if param.type() == 'torch.cuda.HalfTensor':
                         print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                               .format(param.size()))
                         fp16_params_this_group.append(param)
+                        master_param = param.detach().clone().float()
+                        master_param.requires_grad = True
+                        master_params_this_group.append(master_param)
+                        fp32_from_fp16_params_this_group.append(master_param)
+                        # Reset existing state dict key to the new master param.
+                        # We still need to recast per-param state tensors, if any, to FP32.
+                        if param in self.optimizer.state:
+                            self.optimizer.state[master_param] = self.optimizer.state.pop(param)
                     elif param.type() == 'torch.cuda.FloatTensor':
                         print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                               .format(param.size()))
                         fp32_params_this_group.append(param)
+                        master_params_this_group.append(param)
                     else:
                         raise TypeError("Wrapped parameters must be either "
                                         "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                         "Received {}".format(param.type()))
-            fp32_from_fp16_params_this_group = [param.detach().clone().float()
-                                                for param in fp16_params_this_group]
-            for param in fp32_from_fp16_params_this_group:
-                param.requires_grad = True
-            param_group['params'] = fp32_from_fp16_params_this_group + fp32_params_this_group
+            param_group['params'] = master_params_this_group
             self.fp16_groups.append(fp16_params_this_group)
             self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
-        self.optimizer = init_optimizer
+        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
+        self.optimizer.load_state_dict(self.optimizer.state_dict())
+        # alternative way to cast per-param state tensors:
+        # self.optimizer.load_state_dict(init_state_dict)

         if dynamic_loss_scale:
             self.dynamic_loss_scale = True
...
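
The load_state_dict(state_dict()) round trip at the end of the constructor works because torch.optim.Optimizer.load_state_dict casts floating-point state tensors to the dtype and device of the parameter they belong to; once the groups hold FP32 master params, the trip recasts any leftover FP16 state (e.g., Adam's exp_avg) to FP32. A standalone sketch of that mechanism (the hand-seeded state is illustrative, not this class's code):

    import torch

    p16 = torch.nn.Parameter(torch.randn(4, dtype=torch.half))
    opt = torch.optim.Adam([p16], lr=1e-3)
    # Seed per-param state by hand, standing in for state created by earlier
    # FP16 steps or loaded from a checkpoint.
    opt.state[p16] = {'step': 1,
                      'exp_avg': torch.zeros_like(p16),
                      'exp_avg_sq': torch.zeros_like(p16)}

    master = p16.detach().clone().float()
    master.requires_grad = True
    opt.param_groups[0]['params'] = [master]  # swap in the FP32 master param
    opt.state[master] = opt.state.pop(p16)    # re-key existing state to the master

    opt.load_state_dict(opt.state_dict())     # casts float state to each param's dtype
    assert opt.state[master]['exp_avg'].dtype == torch.float32

One nit visible in the unchanged context above: `if not torch.cuda.is_available:` tests the function object, which is always truthy, so the guard can never raise; `torch.cuda.is_available()` is presumably intended.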