Commit 15648029 authored by Michael Carilli

Merge branch 'FDecaYed-deyuf/fused_optimizer_v2'

parents 880ab925 b9f0995b
import torch
import xentropy_cuda

class SoftmaxCrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
        losses, max_log_sum_exp = xentropy_cuda.forward(
            logits, labels, smoothing, half_to_float)
        losses.masked_fill_(labels==padding_idx, 0)

        ctx.save_for_backward(logits, max_log_sum_exp, labels,
                              torch.FloatTensor([smoothing]),
                              torch.LongTensor([padding_idx]))

        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors

        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==padding_idx.item(), 0)
        grad_logits = xentropy_cuda.backward(
            grad_loss.contiguous(), logits, max_log_sum_exp,
            labels, smoothing.item())

        return grad_logits, None, None, None, None
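A minimal usage sketch for the fused loss (illustrative only: the shapes, smoothing value, and padding index below are placeholders, and it assumes the xentropy_cuda extension is built):

import torch
# SoftmaxCrossEntropyLoss as defined above; returns per-example losses
logits = torch.randn(128, 32000, dtype=torch.float16, device='cuda', requires_grad=True)
labels = torch.randint(0, 32000, (128,), device='cuda')
# positional args: logits, labels, smoothing, padding_idx, half_to_float
losses = SoftmaxCrossEntropyLoss.apply(logits, labels, 0.1, 0, True)
losses.sum().backward()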
from .fused_sgd import FusedSGD
from .fused_adam import FusedAdam
from .fused_novograd import FusedNovoGrad
from .fused_lamb import FusedLAMB
from .fp16_optimizer import FP16_Optimizer
...@@ -35,6 +35,8 @@ class FP16_Optimizer(object):
                 dynamic_loss_args=None,
                 verbose=True):
        print("\nfp16_optimizer is designed to only work with apex.optimizers, and will be removed in future")
        print("To update, use updated optimizers with AMP.")
        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain the same user API as apex.fp16_utils
        # 2. keep common stuff here in case we need to add a new fused optimizer later
...
import torch
from apex.multi_tensor_apply import multi_tensor_applier
from amp_C import multi_tensor_adam

class FusedAdam(torch.optim.Optimizer):

    """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    This version of fused Adam implements 2 fusions:
      - fusion of the elementwise operations inside the Adam update
      - batching the updates for a whole group of tensors into a single multi-tensor kernel launch

    This is a breaking change over the previous version: the API has changed, and it no longer fuses grad norm and loss scaling.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
...@@ -21,10 +26,8 @@ class FusedAdam(torch.optim.Optimizer):
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay;
            True for decoupled weight decay (also known as AdamW) (default: True)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
...@@ -32,116 +35,75 @@ class FusedAdam(torch.optim.Optimizer):
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,
                 weight_decay=0., amsgrad=False):
        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay)
        super(FusedAdam, self).__init__(params, defaults)
        self.adam_w_mode = 1 if adam_w_mode else 0
        self.dummy_overflow_buf = torch.cuda.IntTensor([0])

    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
            raise RuntimeError('FusedAdam has been updated, please use with AMP for mixed precision.')
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']

            # assume same step across group now to simplify things
            # per-parameter step can easily be supported by making it a tensor, or passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            p_list, g_list, m1_list, m2_list = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                p_list.append(p.data)
                g_list.append(p.grad.data)
                m1_list.append(state['exp_avg'])
                m2_list.append(state['exp_avg_sq'])

            multi_tensor_applier(multi_tensor_adam,
                                 self.dummy_overflow_buf,
                                 [g_list, p_list, m1_list, m2_list],
                                 group['lr'],
                                 beta1,
                                 beta2,
                                 group['eps'],
                                 group['step'],
                                 self.adam_w_mode,
                                 bias_correction,
                                 group['weight_decay'])

        return loss
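A minimal sketch of the updated FusedAdam API (the model below is a placeholder; for mixed precision, wrap the optimizer with AMP rather than passing grads/scale to step()):

import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(1024, 1024).cuda()          # placeholder model
optimizer = FusedAdam(model.parameters(), lr=1e-3,
                      adam_w_mode=True, weight_decay=0.01)

model(torch.randn(8, 1024, device='cuda')).sum().backward()
optimizer.step()        # one multi-tensor launch updates the whole group
optimizer.zero_grad()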
import torch
from apex.multi_tensor_apply import multi_tensor_applier

class FusedLAMB(torch.optim.Optimizer):

    """Implements LAMB algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_.
            NOT SUPPORTED now! (default: False)
        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay;
            True for decoupled weight decay (also known as AdamW) (default: True)
        grad_averaging (bool, optional): whether to apply (1-beta2) to grad when
            calculating running averages of gradient. (default: True)
        set_grad_none (bool, optional): whether to set grad to None when zero_grad()
            method is called. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm
            (default: 1.0)
    """

    def __init__(self, params, lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                 amsgrad=False, adam_w_mode=True,
                 grad_averaging=True, set_grad_none=True,
                 max_grad_norm=1.0):
        if amsgrad:
            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)
        super(FusedLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        else:
            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none

    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(FusedLAMB, self).zero_grad()

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # assume same step across group now to simplify things
            # per-parameter step can easily be supported by making it a tensor, or passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16 = [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

            if len(g_16) > 0:
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.adam_w_mode,
                                     group['max_grad_norm'])
            if len(g_32) > 0:
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.adam_w_mode,
                                     group['max_grad_norm'])

        return loss
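A minimal FusedLAMB usage sketch (placeholder model and hyperparameters; max_grad_norm is handed straight to the fused kernel, so no separate clipping pass is needed):

import torch
from apex.optimizers import FusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()          # placeholder model
optimizer = FusedLAMB(model.parameters(), lr=1e-3,
                      weight_decay=0.01, max_grad_norm=1.0)

model(torch.randn(8, 1024, device='cuda')).sum().backward()
optimizer.step()
optimizer.zero_grad()   # sets .grad to None (set_grad_none=True by default)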
import torch
from apex.multi_tensor_apply import multi_tensor_applier

class FusedNovoGrad(torch.optim.Optimizer):

    """Implements NovoGrad algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    It has been proposed in `Jasper: An End-to-End Convolutional Neural Acoustic Model`_.
    More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_.
            NOT SUPPORTED now! (default: False)
        reg_inside_moment (bool, optional): whether to do regularization (norm and L2)
            in the momentum calculation. True to include it there, False to apply it
            only on the update term. (default: False)
        grad_averaging (bool, optional): whether to apply (1-beta2) to grad when
            calculating running averages of gradient. (default: True)
        norm_type (int, optional): which norm to calculate for each layer.
            2 for L2 norm, and 0 for infinite norm. These are the only 2
            supported types now. (default: 2)
        init_zero (bool, optional): whether to init the norm with 0 (start averaging
            on 1st step) or with the first step's norm (start averaging on 2nd step).
            True for init with 0. (default: False)
        set_grad_none (bool, optional): whether to set grad to None when zero_grad()
            method is called. (default: True)

    .. _Jasper\: An End-to-End Convolutional Neural Acoustic Model:
        https://arxiv.org/abs/1904.03288
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-8, weight_decay=0.,
                 amsgrad=False, reg_inside_moment=False,
                 grad_averaging=True, norm_type=2, init_zero=False,
                 set_grad_none=True):
        if amsgrad:
            raise RuntimeError('FusedNovoGrad does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging, norm_type=norm_type,
                        init_zero=init_zero)
        super(FusedNovoGrad, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_novograd = amp_C.multi_tensor_novograd
        else:
            raise RuntimeError('apex.optimizers.FusedNovoGrad requires cuda extensions')

        self.moment_mode = 0 if reg_inside_moment else 1
        self.set_grad_none = set_grad_none

    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(FusedNovoGrad, self).zero_grad()

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # assume same step across group now to simplify things
            # per-parameter step can easily be supported by making it a tensor, or passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16 = [], [], []
            g_32, p_32, m_32 = [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedNovoGrad does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                else:
                    raise RuntimeError('FusedNovoGrad only supports fp16 and fp32.')

            # we store the per-weight norm as one tensor per group/precision combination
            # different from optim.Adam, we store the norm here (not ^2) so we can unify the calculation across norm types
            if 'exp_avg_sq' not in group:
                group['exp_avg_sq'] = [None, None]
                if group['init_zero']:
                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16)).contiguous().fill_(0)
                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32)).contiguous().fill_(0)
                else: # init with first step norm, so the first blend has no effect
                    if group['norm_type'] == 0:
                        v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16]
                        v_32 = [torch.max(torch.abs(g)).item() for g in g_32]
                    elif group['norm_type'] == 2:
                        v_16 = [torch.sum(torch.pow(g.to(torch.float32), 2)).sqrt().item() for g in g_16]
                        v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32]
                    else:
                        raise RuntimeError('FusedNovoGrad only supports l2/inf norm now.')
                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16)
                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32)
            else:
                assert(len(g_16) == group['exp_avg_sq'][0].numel())
                assert(len(g_32) == group['exp_avg_sq'][1].numel())

            if len(g_16) > 0:
                multi_tensor_applier(self.multi_tensor_novograd,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16],
                                     group['exp_avg_sq'][0],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.moment_mode,
                                     group['norm_type'])
            if len(g_32) > 0:
                multi_tensor_applier(self.multi_tensor_novograd,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32],
                                     group['exp_avg_sq'][1],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.moment_mode,
                                     group['norm_type'])

        return loss
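A minimal FusedNovoGrad usage sketch (placeholder model and hyperparameters; norm_type and init_zero control the per-layer norm state documented above):

import torch
from apex.optimizers import FusedNovoGrad

model = torch.nn.Linear(1024, 1024).cuda()          # placeholder model
# norm_type=2 keeps an L2-norm EMA per layer; init_zero=False seeds it with
# the first step's norm so the first blend has no effect
optimizer = FusedNovoGrad(model.parameters(), lr=1e-2, betas=(0.95, 0.98),
                          norm_type=2, init_zero=False)

model(torch.randn(8, 1024, device='cuda')).sum().backward()
optimizer.step()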
import torch
from torch.optim.optimizer import Optimizer, required
from apex.multi_tensor_apply import multi_tensor_applier

class FusedSGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
             v = \rho * v + g \\
             p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
             v = \rho * v + lr * g \\
             p = p - v

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False,
                 wd_after_momentum=False,
                 materialize_master_grads=True):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(FusedSGD, self).__init__(params, defaults)

        self.wd_after_momentum = wd_after_momentum
        self.materialize_master_grads = materialize_master_grads
        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False

        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
        else:
            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')

    def __setstate__(self, state):
        super(FusedSGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def get_momentums(self, params):
        momentums = []
        first_run = True
        for p in params:
            param_state = self.state[p]
            # torch.optim.SGD initializes momentum in the main loop, we have
            # to do it here, and track whether or not we've done so, so that
            # momentum application can be skipped in the main kernel.
            if 'momentum_buffer' not in param_state:
                first_run = True
                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                momentums.append(buf)
            else:
                first_run = False
                momentums.append(param_state['momentum_buffer'])
        return momentums, first_run

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        explicit_master_params = (hasattr(self, "_amp_stash") and
                                  hasattr(self._amp_stash, "fp32_from_fp16_groups"))

        for gid, group in enumerate(self.param_groups):
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            # For each group, there are 3 possible combinations we need to consider:
            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
            # 1. fp16, fp16, fp16, No
            # 2. fp32, fp32, fp32, No
            # 3. fp16, fp32, fp32, Yes

            first_runs = [True, True]

            # I think a bit of code divergence in exchange for naming clarity is worthwhile
            if explicit_master_params:
                stash = self._amp_stash

                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

                if self.materialize_master_grads:
                    fp16_model_params = [p for i, p in enumerate(
                        stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)

                    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
                                fp32_from_fp16_momentums, fp16_model_params]
                else:
                    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_params = [p for i, p in enumerate(
                        stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)

                    fp16_set = [fp16_model_grads, fp32_from_fp16_params,
                                fp32_from_fp16_momentums, fp16_model_params]

                launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
            else:
                fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)

                fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
                fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

                launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
                               [fp32_grads, fp32_params, fp32_momentums]]

            for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
                assert len(launch_set[0]) == len(launch_set[1])
                assert len(launch_set[0]) == len(launch_set[2])
                if len(launch_set[0]) > 0:
                    multi_tensor_applier(
                        self.multi_tensor_sgd,
                        self._dummy_overflow_buf,
                        launch_set,
                        weight_decay,
                        momentum,
                        dampening,
                        group['lr'],
                        nesterov,
                        first_run,
                        self.wd_after_momentum,
                        1.0/self.most_recent_scale)

        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False

        return loss
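A sketch of FusedSGD under AMP, which exercises the explicit-master-params path above (placeholder model; the opt_level and shapes are illustrative):

import torch
from apex import amp
from apex.optimizers import FusedSGD

model = torch.nn.Linear(1024, 1024).cuda()          # placeholder model
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9)
# O2 keeps fp32 master weights; the optimizer finds them via its _amp_stash
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

loss = model(torch.randn(8, 1024, device='cuda')).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()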
...@@ -55,10 +55,11 @@ class SyncBatchNorm(_BatchNorm):
        >>> inp = torch.randn(10, 14, 14, 100).cuda()
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        self.process_group = process_group
        self.channel_last = channel_last
        self.fuse_relu = fuse_relu

    def _specify_process_group(self, process_group):
        self.process_group = process_group
...@@ -66,11 +67,11 @@ class SyncBatchNorm(_BatchNorm):
    def _specify_channel_last(self, channel_last):
        self.channel_last = channel_last

    def forward(self, input, z = None):
        # if input.dim() == 2, we switch to channel_last for efficient memory accessing
        channel_last = self.channel_last if input.dim() != 2 else True

        if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
            # fall back to pytorch implementation for inference
            return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
        else:
...@@ -81,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:
                exponential_average_factor = self.momentum
            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
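A sketch of the new fused path (assumes NHWC tensors and that the syncbn extension is built; judging from the kernel names, fuse_relu with a residual input z folds the add and ReLU into the batchnorm launch):

import torch
from apex.parallel import SyncBatchNorm

bn = SyncBatchNorm(64, channel_last=True, fuse_relu=True).cuda()
x = torch.randn(8, 14, 14, 64, device='cuda')   # NHWC activations
z = torch.randn(8, 14, 14, 64, device='cuda')   # residual branch
out = bn(x, z)   # roughly relu(bn(x) + z), computed in one kernel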
...@@ -7,7 +7,7 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function):

    @staticmethod
    def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):
        torch.cuda.nvtx.range_push("sync_BN_fw")
        input = input.contiguous()
        world_size = 0
...@@ -53,13 +53,14 @@ class SyncBatchnormFunction(Function):
            mean = running_mean.data
            inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

        ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
        ctx.process_group = process_group
        ctx.channel_last = channel_last
        ctx.world_size = world_size
        ctx.fuse_relu = fuse_relu

        if channel_last:
            out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)
        else:
            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
...@@ -73,11 +74,17 @@ class SyncBatchnormFunction(Function):
        # mini batch mean & var are calculated by forward path.
        # mu = 1./N*np.sum(h, axis = 0)
        # var = 1./N*np.sum((h-mu)**2, axis = 0)
        saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
        process_group = ctx.process_group
        channel_last = ctx.channel_last
        world_size = ctx.world_size
        fuse_relu = ctx.fuse_relu
        grad_input = grad_z = grad_weight = grad_bias = None

        if fuse_relu:
            grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)
        if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
            grad_z = grad_output.clone()

        # TODO(jie): why do I have to clone here? life time of grad_output?
        if channel_last:
...@@ -100,11 +107,11 @@ class SyncBatchnormFunction(Function):
        else:
            grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)

        if weight is None or not ctx.needs_input_grad[2]:
            grad_weight = None

        if weight is None or not ctx.needs_input_grad[3]:
            grad_bias = None

        torch.cuda.nvtx.range_pop()
        return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
...@@ -6,6 +6,19 @@ void multi_tensor_scale_cuda(
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale);

void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale);

void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
...@@ -40,9 +53,55 @@ void multi_tensor_lamb_stage2_cuda(
  at::Tensor per_tensor_update_norm,
  const float step_size);

void multi_tensor_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int mode,
  const int bias_correction,
  const float weight_decay);

void multi_tensor_novograd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor grad_norms,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  const int norm_type);

void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  const float max_grad_norm);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
...@@ -51,4 +110,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "Computes update part of LAMB optimizer");
  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
        "Completes application of gradient to parameters for LAMB optimizer");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
        "Compute and apply gradient update to parameters for NovoGrad optimizer");
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Computes and applies update for LAMB optimizer");
}
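These bindings are normally reached through apex.multi_tensor_apply rather than called directly; a minimal sketch with the pre-existing multi_tensor_scale binding (tensor counts and sizes are illustrative):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])
src = [torch.randn(1024, device='cuda') for _ in range(4)]
dst = [torch.empty_like(t) for t in src]
# fused overflow check + scale: dst = src * 0.5 across all four tensors,
# chunked into a handful of kernel launches by the applier
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 0.5)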
...@@ -3,6 +3,9 @@
// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
...@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("adam", &adam, "Adam optimized CUDA implementation.");
  m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}
...@@ -55,10 +55,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                         const at::optional<at::Tensor> z,
                                         const at::Tensor mean,
                                         const at::Tensor inv_std,
                                         const at::optional<at::Tensor> weight,
                                         const at::optional<at::Tensor> shift,
                                         const bool fuse_relu);

// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
...@@ -82,6 +84,15 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                          const at::Tensor mean_dy,
                                          const at::Tensor mean_dy_xmu);

at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
                                     const at::Tensor input,
                                     const at::optional<at::Tensor> z,
                                     const at::Tensor mean,
                                     const at::Tensor inv_std,
                                     const at::optional<at::Tensor> weight,
                                     const at::optional<at::Tensor> shift);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
  m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
...@@ -92,4 +103,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
  m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}
...@@ -128,3 +128,53 @@ __device__ __forceinline__ T reduce_block_into_lanes
  return final;
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
  (T *x,
   T val,
   int lanes=1,
   bool share_result=false) // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y*blockDim.x;
  int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.

  if(blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

  #pragma unroll
  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if(tid < i)
      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
    __syncthreads();
  }

  T final;

  if(tid < 32)
  {
    if(blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
    else
      final = val;
    // __SYNCWARP();

    #pragma unroll
    for(int i = 16; i >= lanes; i >>= 1)
      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
  }

  if(share_result)
  {
    if(tid < lanes)
      x[tid] = final; // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
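For reference, the net effect of this reduction for lanes=1, modeled in plain Python (a semantic sketch only; the kernel computes it cooperatively via shared memory and warp shuffles):

# every thread contributes |val|; lane 0 ends up holding the block-wide maximum
def block_reduce_max_abs(vals):
    return max(abs(v) for v in vals)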