Commit 4d6ed501 authored by Deyu Fu

Merge branch 'multi_tensor_sgd' into deyuf/fused_optimizer_v2

parents 690b1f71 9f64bf27
import torch
import xentropy_cuda

class SoftmaxCrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
        losses, max_log_sum_exp = xentropy_cuda.forward(
            logits, labels, smoothing, half_to_float)
        losses.masked_fill_(labels==padding_idx, 0)

        ctx.save_for_backward(logits, max_log_sum_exp, labels,
                              torch.FloatTensor([smoothing]),
                              torch.LongTensor([padding_idx]))

        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors

        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==padding_idx.item(), 0)
        grad_logits = xentropy_cuda.backward(
            grad_loss.contiguous(), logits, max_log_sum_exp,
            labels, smoothing.item())

        return grad_logits, None, None, None, None
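A minimal usage sketch for the Function above (not part of this diff; the shapes, smoothing value, and padding_idx are illustrative, and the xentropy_cuda extension must be built):

logits = torch.randn(32, 32000, device="cuda", dtype=torch.half, requires_grad=True)
labels = torch.randint(0, 32000, (32,), device="cuda")

# positional args: logits, labels, smoothing, padding_idx, half_to_float
losses = SoftmaxCrossEntropyLoss.apply(logits, labels, 0.1, 0, False)
losses.sum().backward()   # gradients flow back to logits only; the other inputs get None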
from .fused_sgd import FusedSGD
from .novograd import FusedNovoGrad
from .fused_adam_v1 import FusedAdam_v1
from .adam import FusedAdam
#from .sgd import FusedSGD
from .fp16_optimizer import FP16_Optimizer
@@ -2,8 +2,9 @@ import types
import torch
import importlib

from ..multi_tensor_apply import multi_tensor_applier

class FusedAdam_v1(torch.optim.Optimizer):
    """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.
@@ -25,6 +26,8 @@ class FusedAdam_v1(torch.optim.Optimizer):
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        use_mt (boolean, optional): use multi tensor apply for lower launch
            latency. (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
@@ -35,10 +38,21 @@ class FusedAdam_v1(torch.optim.Optimizer):
    def __init__(self, params,
                 lr=1e-3, bias_correction = True,
                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
                 amp_scale_adjustment=1.0):
        global fused_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")

        self._use_multi_tensor = False
        if use_mt:
            if not multi_tensor_applier.available:
                print("Warning: multi_tensor_applier is unavailable")
            else:
                self._use_multi_tensor = True
                self._overflow_buf = torch.cuda.IntTensor([0])

        self._amp_scale_adjustment = amp_scale_adjustment

        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
@@ -66,6 +80,12 @@ class FusedAdam_v1(torch.optim.Optimizer):
        if closure is not None:
            loss = closure()

        if hasattr(self, "_amp_stash"):
            grads = self._amp_stash.grads
            output_params = self._amp_stash.output_params
            scale = self._amp_stash.scale*self._amp_scale_adjustment
            grad_norms = self._amp_stash.grad_norms

        if grads is None:
            grads_group = [None]*len(self.param_groups)
        # backward compatibility
@@ -105,6 +125,12 @@ class FusedAdam_v1(torch.optim.Optimizer):
            bias_correction = 1 if group['bias_correction'] else 0

            if self._use_multi_tensor:
                if output_params:
                    tensorlists = [[],[],[],[],[]]
                else:
                    tensorlists = [[],[],[],[]]

            for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                if p.grad is None and grad is None:
@@ -130,18 +156,43 @@ class FusedAdam_v1(torch.optim.Optimizer):
                state['step'] += 1

                out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
                if self._use_multi_tensor:
                    pl = [p.data, exp_avg, exp_avg_sq, grad]
                    if output_param is not None:
                        pl.append(out_p)

                    for tl, t in zip(tensorlists, pl):
                        tl.append(t)
                else:
                    fused_adam_cuda.adam(p.data,
                                         out_p,
                                         exp_avg,
                                         exp_avg_sq,
                                         grad,
                                         group['lr'],
                                         beta1,
                                         beta2,
                                         group['eps'],
                                         combined_scale,
                                         state['step'],
                                         self.eps_mode,
                                         bias_correction,
                                         group['weight_decay'])

            if self._use_multi_tensor:
                multi_tensor_applier(
                    fused_adam_cuda.adam_mt,
                    self._overflow_buf,
                    tensorlists,
                    group['lr'],
                    beta1,
                    beta2,
                    group['eps'],
                    combined_scale,
                    state['step'],
                    self.eps_mode,
                    bias_correction,
                    group['weight_decay'])

        return loss
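A usage sketch for the multi-tensor path added above (not part of this diff; the model and hyperparameters are illustrative, and the fused_adam_cuda extension must be installed):

from apex.optimizers import FusedAdam_v1

model = torch.nn.Linear(1024, 1024).cuda()
# use_mt=True gathers the per-parameter updates into tensorlists and issues a
# single fused_adam_cuda.adam_mt launch per parameter group via multi_tensor_applier.
optimizer = FusedAdam_v1(model.parameters(), lr=1e-3, betas=(0.9, 0.999), use_mt=True)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()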
import torch
from torch.optim.optimizer import Optimizer, required

from apex.multi_tensor_apply import multi_tensor_applier

class FusedSGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
                  v = \rho * v + g \\
                  p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
             v = \rho * v + lr * g \\
             p = p - v

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False,
                 wd_after_momentum=False,
                 materialize_master_grads=True):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(FusedSGD, self).__init__(params, defaults)

        self.wd_after_momentum = wd_after_momentum
        self.materialize_master_grads = materialize_master_grads
        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False

        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
        else:
            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')

    def __setstate__(self, state):
        super(FusedSGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def get_momentums(self, params):
        momentums = []
        first_run = True
        for p in params:
            param_state = self.state[p]
            # torch.optim.SGD initializes momentum in the main loop, we have
            # to do it here, and track whether or not we've done so, so that
            # momentum application can be skipped in the main kernel.
            if 'momentum_buffer' not in param_state:
                first_run = True
                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                momentums.append(buf)
            else:
                first_run = False
                momentums.append(param_state['momentum_buffer'])
        return momentums, first_run

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        explicit_master_params = (hasattr(self, "_amp_stash") and
                                  hasattr(self._amp_stash, "fp32_from_fp16_groups"))

        for gid, group in enumerate(self.param_groups):
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            # For each group, there are 3 possible combinations we need to consider:
            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
            # 1. fp16, fp16, fp16, No
            # 2. fp32, fp32, fp32, No
            # 3. fp16, fp32, fp32, Yes

            first_runs = [True, True]

            # I think a bit of code divergence in exchange for naming clarity is worthwhile
            if explicit_master_params:
                stash = self._amp_stash

                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

                if self.materialize_master_grads:
                    fp16_model_params = [p for i, p in enumerate(
                        stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)

                    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
                                fp32_from_fp16_momentums, fp16_model_params]
                else:
                    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_params = [p for i, p in enumerate(
                        stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)

                    fp16_set = [fp16_model_grads, fp32_from_fp16_params,
                                fp32_from_fp16_momentums, fp16_model_params]

                launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
            else:
                fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)

                fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
                fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

                launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
                               [fp32_grads, fp32_params, fp32_momentums]]

            for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
                assert len(launch_set[0]) == len(launch_set[1])
                assert len(launch_set[0]) == len(launch_set[2])
                if len(launch_set[0]) > 0:
                    multi_tensor_applier(
                        self.multi_tensor_sgd,
                        self._dummy_overflow_buf,
                        launch_set,
                        weight_decay,
                        momentum,
                        dampening,
                        group['lr'],
                        nesterov,
                        first_run,
                        self.wd_after_momentum,
                        1.0/self.most_recent_scale)

        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False

        return loss
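A usage sketch for FusedSGD (not part of this diff; the model and hyperparameters are illustrative, and the amp_C extension must be installed):

from apex.optimizers import FusedSGD

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9,
                     weight_decay=1e-4, nesterov=True)

optimizer.zero_grad()
model(torch.randn(8, 1024, device="cuda")).sum().backward()
optimizer.step()   # one amp_C.multi_tensor_sgd launch per fp16/fp32 bucket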
@@ -55,10 +55,11 @@ class SyncBatchNorm(_BatchNorm):
        >>> inp = torch.randn(10, 14, 14, 100).cuda()
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        self.process_group = process_group
        self.channel_last = channel_last
        self.fuse_relu = fuse_relu

    def _specify_process_group(self, process_group):
        self.process_group = process_group
@@ -66,11 +67,11 @@ class SyncBatchNorm(_BatchNorm):
    def _specify_channel_last(self, channel_last):
        self.channel_last = channel_last

    def forward(self, input, z = None):
        # if input.dim() == 2, we switch to channel_last for efficient memory accessing
        channel_last = self.channel_last if input.dim() != 2 else True

        if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
            # fall back to pytorch implementation for inference
            return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
        else:
@@ -81,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:
                exponential_average_factor = self.momentum
            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
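A sketch of how the new z / fuse_relu arguments might be used (not part of this diff; shapes are illustrative, and the call assumes an initialized torch.distributed process group and the syncbn extension):

from apex.parallel import SyncBatchNorm

bn = SyncBatchNorm(64, channel_last=True, fuse_relu=True).cuda()
x = torch.randn(8, 14, 14, 64, device="cuda")   # NHWC activations
residual = torch.randn_like(x)                  # optional skip-connection input

y = bn(x, z=residual)   # computes relu(batch_norm(x) + residual) in one fused kernel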
@@ -7,7 +7,7 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function):

    @staticmethod
    def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):
        torch.cuda.nvtx.range_push("sync_BN_fw")
        input = input.contiguous()
        world_size = 0
@@ -53,13 +53,14 @@ class SyncBatchnormFunction(Function):
            mean = running_mean.data
            inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

        ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
        ctx.process_group = process_group
        ctx.channel_last = channel_last
        ctx.world_size = world_size
        ctx.fuse_relu = fuse_relu

        if channel_last:
            out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)
        else:
            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
@@ -73,11 +74,17 @@ class SyncBatchnormFunction(Function):
        # mini batch mean & var are calculated by forward path.
        # mu = 1./N*np.sum(h, axis = 0)
        # var = 1./N*np.sum((h-mu)**2, axis = 0)
        saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
        process_group = ctx.process_group
        channel_last = ctx.channel_last
        world_size = ctx.world_size
        fuse_relu = ctx.fuse_relu
        grad_input = grad_z = grad_weight = grad_bias = None

        if fuse_relu:
            grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)
        if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
            grad_z = grad_output.clone()

        # TODO(jie): why do I have to clone here? life time of grad_output?
        if channel_last:
@@ -100,11 +107,11 @@ class SyncBatchnormFunction(Function):
        else:
            grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)

        if weight is None or not ctx.needs_input_grad[2]:
            grad_weight = None

        if weight is None or not ctx.needs_input_grad[3]:
            grad_bias = None
        torch.cuda.nvtx.range_pop()

        return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
@@ -6,6 +6,19 @@ void multi_tensor_scale_cuda(
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale);

void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale);

void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
@@ -72,6 +85,8 @@ void multi_tensor_novograd_cuda(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
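For reference, a hedged sketch of how the new binding is driven from Python through multi_tensor_applier (tensor-list layout follows FusedSGD.step above; chunking and the noop flag are handled by the applier; all values shown are illustrative):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])
params    = [torch.zeros(1024, device="cuda") for _ in range(3)]
grads     = [torch.randn_like(p) for p in params]
momentums = [torch.zeros_like(p) for p in params]

# 3-list variant: [grads, params, momentums]; a 4th list of fp16 model copies
# is appended when master fp32 weights must be written back to an fp16 model.
multi_tensor_applier(amp_C.multi_tensor_sgd, overflow_buf,
                     [grads, params, momentums],
                     0.0,     # wd
                     0.9,     # momentum
                     0.0,     # dampening
                     0.1,     # lr
                     False,   # nesterov
                     True,    # first_run (momentum buffers are fresh)
                     False,   # wd_after_momentum
                     1.0)     # scale (1/loss_scale)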
@@ -3,6 +3,9 @@
// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("adam", &adam, "Adam optimized CUDA implementation.");
  m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}
@@ -9,6 +9,10 @@
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

#include "type_shim.h"
@@ -55,6 +59,93 @@ __global__ void adam_cuda_kernel(
  }
}
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<DEPTH>& tl,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    adamMode_t mode,
    const float decay)
  {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T* p = (T *)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;
    T* m = (T *)tl.addresses[1][tensor_loc];
    m += chunk_idx*chunk_size;
    T* v = (T *)tl.addresses[2][tensor_loc];
    v += chunk_idx*chunk_size;
    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
    g += chunk_idx*chunk_size;
    GRAD_T* p_copy = NULL;
    if (DEPTH == 5) {
      p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
      p_copy += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    T incoming_p[ILP];
    T incoming_m[ILP];
    T incoming_v[ILP];
    T incoming_g[ILP];

    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP) {

      #pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        incoming_p[ii] = 0;
        incoming_m[ii] = 0;
        incoming_v[ii] = 0;
        incoming_g[ii] = 0;

        int i = i_start + threadIdx.x + ii*blockDim.x;
        if (i < n && i < chunk_size) {
          incoming_p[ii] = p[i];
          incoming_m[ii] = m[i];
          incoming_v[ii] = v[i];
          incoming_g[ii] = static_cast<T>(g[i]);
        }
      }

      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling
      // the write loop, since writes just fire off once their LDGs arrive.
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop though.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int j = i_start + threadIdx.x + ii*blockDim.x;
        if(j < n && j < chunk_size) {
          T scaled_grad = incoming_g[ii]/grad_scale;
          m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
          v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
          float denom;
          if (mode == ADAM_MODE_0)
            denom = sqrtf(v[j] + eps);
          else // Mode 1
            denom = sqrtf(v[j]) + eps;
          float update = (m[j]/denom) + (decay*incoming_p[ii]);
          p[j] = incoming_p[ii] - (step_size*update);
          if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
        }
      }
    }
  }
};
void fused_adam_cuda(
  at::Tensor & p,
  at::Tensor & p_copy,
@@ -96,7 +187,7 @@ void fused_adam_cuda(
    AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
    //dispatch is done on the gradient type
    using namespace at; // prevents "toString is undefined" errors
    DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
        p.data<accscalar_t>(),
@@ -112,7 +203,7 @@ void fused_adam_cuda(
        tsize,
        (adamMode_t) mode,
        decay);
    );
  } else {
    using namespace at;
    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
@@ -135,3 +226,110 @@ void fused_adam_cuda(
  THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
  float lr,
  float beta1,
  float beta2,
  float eps,
  float grad_scale,
  int step,
  int mode,
  int bias_correction,
  float decay) {

  //Constants
  float step_size = 0;
  if (bias_correction == 1) {
    const float bias_correction1 = 1 - std::pow(beta1, step);
    const float bias_correction2 = 1 - std::pow(beta2, step);
    step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
  }
  else {
    step_size = lr;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  size_t tl_sz = tensor_lists.size();
  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

  if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
    //all other values should be fp32 for half gradients
    AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
    //dispatch is done on the gradient type
    if (tl_sz == 5) {
      DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        multi_tensor_apply<5>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AdamFunctor<5, accscalar_t, scalar_t_0>(),
          beta1,
          beta2,
          eps,
          grad_scale,
          step_size,
          (adamMode_t) mode,
          decay);
      );
    } else {
      DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        multi_tensor_apply<4>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AdamFunctor<4, accscalar_t, scalar_t_0>(),
          beta1,
          beta2,
          eps,
          grad_scale,
          step_size,
          (adamMode_t) mode,
          decay);
      );
    }
  } else {
    if (tl_sz == 5) {
      DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
        multi_tensor_apply<5>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AdamFunctor<5, scalar_t_0, scalar_t_0>(),
          beta1,
          beta2,
          eps,
          grad_scale,
          step_size,
          (adamMode_t) mode,
          decay);
      );
    } else {
      DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
        multi_tensor_apply<4>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AdamFunctor<4, scalar_t_0, scalar_t_0>(),
          beta1,
          beta2,
          eps,
          grad_scale,
          step_size,
          (adamMode_t) mode,
          decay);
      );
    }
  }
  THCudaCheck(cudaGetLastError());
}
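From Python, the new adam_mt entry point is reached through multi_tensor_applier with the tensor lists ordered as the comment above notes ([p, m, v, g], plus an optional fp16 parameter copy); a hedged sketch mirroring the FusedAdam_v1.step path, with illustrative values:

import torch
import fused_adam_cuda
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])
params     = [torch.zeros(1024, device="cuda") for _ in range(3)]
exp_avg    = [torch.zeros_like(p) for p in params]
exp_avg_sq = [torch.zeros_like(p) for p in params]
grads      = [torch.randn_like(p) for p in params]

multi_tensor_applier(fused_adam_cuda.adam_mt, overflow_buf,
                     [params, exp_avg, exp_avg_sq, grads],
                     1e-3,    # lr
                     0.9,     # beta1
                     0.999,   # beta2
                     1e-8,    # eps
                     1.0,     # grad_scale
                     1,       # step
                     0,       # mode (0: eps added under the sqrt)
                     1,       # bias_correction
                     0.0)     # weight decay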
@@ -55,10 +55,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                         const at::optional<at::Tensor> z,
                                         const at::Tensor mean,
                                         const at::Tensor inv_std,
                                         const at::optional<at::Tensor> weight,
                                         const at::optional<at::Tensor> shift,
                                         const bool fuse_relu);

// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
@@ -82,6 +84,15 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                          const at::Tensor mean_dy,
                                          const at::Tensor mean_dy_xmu);

at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
                                     const at::Tensor input,
                                     const at::optional<at::Tensor> z,
                                     const at::Tensor mean,
                                     const at::Tensor inv_std,
                                     const at::optional<at::Tensor> weight,
                                     const at::optional<at::Tensor> shift);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
  m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
@@ -92,4 +103,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
  m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}
@@ -591,6 +591,58 @@ template <
    int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int reduction_size,
      const int stride,
      const bool fuse_relu) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  auto m_c = mean[c_offset];
  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
        if (z != NULL) {
          tmp += z[address_base];
        }
        out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}

// elementwise BN kernel
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void relu_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
@@ -619,9 +671,11 @@ __global__ void batchnorm_forward_c_last_kernel(
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
        if (z != NULL) {
          tmp += z[address_base];
        }
        out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
@@ -1147,10 +1201,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
at::Tensor batchnorm_forward_c_last_CUDA(
    const at::Tensor input,
    const at::optional<at::Tensor> z,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift,
    const bool fuse_relu) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;
@@ -1170,13 +1226,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
        input.data<scalar_t_0>(),
        z.has_value() ? z.value().data<scalar_t_0>() : NULL,
        mean.data<accscalar_t>(),
        inv_std.data<accscalar_t>(),
        weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
        shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
        out.data<scalar_t_0>(),
        reduction_size,
        stride,
        fuse_relu);
    );
  } else {
    if (weight.has_value()) {
@@ -1189,13 +1247,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
        input.data<scalar_t_0>(),
        z.has_value() ? z.value().data<scalar_t_0>() : NULL,
        mean.data<accscalar_t>(),
        inv_std.data<accscalar_t>(),
        weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
        shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
        out.data<scalar_t_0>(),
        reduction_size,
        stride,
        fuse_relu);
    );
  }
  return out;
@@ -1351,3 +1411,66 @@ at::Tensor batchnorm_backward_c_last_CUDA(
  return grad_input;
}
at::Tensor relu_backward_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::optional<at::Tensor> z,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift) {

  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor out = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
        grad_output.data<scalar_t_0>(),
        input.data<scalar_t_0>(),
        z.has_value() ? z.value().data<scalar_t_0>() : NULL,
        mean.data<accscalar_t>(),
        inv_std.data<accscalar_t>(),
        weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
        shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
        out.data<scalar_t_0>(),
        reduction_size,
        stride);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
        grad_output.data<scalar_t_0>(),
        input.data<scalar_t_0>(),
        z.has_value() ? z.value().data<scalar_t_0>() : NULL,
        mean.data<accscalar_t>(),
        inv_std.data<accscalar_t>(),
        weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
        shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
        out.data<scalar_t_0>(),
        reduction_size,
        stride);
    );
  }
  return out;
}