Unverified commit b56a2088 authored by mcarilli, committed by GitHub

Merge pull request #92 from FDecaYed/deyuf/fused_adam

WIP: improve fused adam
parents 3f9b5c98 b6188fc4
from .fused_adam import FusedAdam
from .fp16_optimizer import FP16_Optimizer
#include <torch/extension.h>

// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// C++ interface
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
    CHECK_INPUT(p)
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
    CHECK_INPUT(m);
...@@ -20,7 +20,7 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
...@@ -28,7 +28,8 @@ __global__ void adam_cuda_kernel(
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay) {
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
...@@ -46,7 +47,8 @@ __global__ void adam_cuda_kernel(
            denom = sqrtf(v[j] + eps);
        else // Mode 1
            denom = sqrtf(v[j]) + eps;
        float update = (m[j]/denom) + (decay*p[j]);
        p[j] = p[j] - (step_size*update);
        if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
    }
}
...@@ -63,7 +65,9 @@ void fused_adam_cuda(
        float eps,
        float grad_scale,
        int step,
        int mode,
        int bias_correction,
        float decay) {
    // Get tensor size
    int tsize = p.numel();
...@@ -72,9 +76,15 @@ void fused_adam_cuda(
    const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");

    // Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
    }
    else {
        step_size = lr;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (g.type().scalarType() == at::ScalarType::Half) {
...@@ -95,7 +105,8 @@ void fused_adam_cuda(
                grad_scale,
                step_size,
                tsize,
                (adamMode_t) mode,
                decay);
        }));
    } else {
        AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
...@@ -111,7 +122,8 @@ void fused_adam_cuda(
                grad_scale,
                step_size,
                tsize,
                (adamMode_t) mode,
                decay);
        }));
    }
    THCudaCheck(cudaGetLastError());
...
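For reference, the per-element update the kernel now performs can be written out in plain Python. This is only an illustrative sketch: the moment updates (the `m` and `v` lines) sit in a part of the kernel elided from this diff and are assumed to follow the standard Adam recurrences; `mode` selects where `eps` is applied, and `decay` folds weight decay into the update direction.

import math

def adam_element_reference(p, m, v, g, lr, beta1, beta2, eps,
                           grad_scale, step, mode, bias_correction, decay):
    # Host-side step size, mirroring fused_adam_cuda(): bias correction is optional.
    if bias_correction:
        step_size = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
    else:
        step_size = lr
    # Undo the loss scale applied during backward (assumed from the elided kernel body).
    scaled_grad = g / grad_scale
    m = beta1 * m + (1 - beta1) * scaled_grad
    v = beta2 * v + (1 - beta2) * scaled_grad * scaled_grad
    # Mode 0: eps inside the sqrt; mode 1: eps added after the sqrt.
    denom = math.sqrt(v + eps) if mode == 0 else math.sqrt(v) + eps
    # Weight decay is folded into the update direction and scaled by step_size.
    update = m / denom + decay * p
    return p - step_size * update, m, v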
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import ctypes
lib = ctypes.cdll.LoadLibrary(None)
lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p]
lib.THCudaHalfTensor_normall.restype = ctypes.c_float
def fused_norm(input):
    if input.type() == 'torch.cuda.HalfTensor':
        # 16384 is the fp16 bit pattern for 2.0 (0x4000): the norm order p=2,
        # passed to the THC kernel as a half value.
        return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
                                            input._cdata, 16384)
    else:
        return input.norm()
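A small sketch of the flattening trick used by FP16_Optimizer below (shapes and values are made up for illustration): _flatten_dense_tensors concatenates a group's parameters into one contiguous buffer, and pointing each parameter's .data at the matching _unflatten_dense_tensors slice turns the model weights into views of that buffer, so one fused kernel launch can update the whole group.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

params = [torch.randn(2, 3).cuda().half(), torch.randn(5).cuda().half()]
flat = _flatten_dense_tensors([p.clone().detach() for p in params])  # one contiguous fp16 buffer
for p, q in zip(params, _unflatten_dense_tensors(flat, params)):
    p.data = q.data                     # each param now aliases a slice of `flat`
master = flat.clone().float().detach()  # fp32 master copy handed to the wrapped optimizer
master.requires_grad = True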
class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is a cut-down version of apex.fp16_utils.FP16_Optimizer.
    It is designed to be used in the same way, but supports only the fused optimizers in apex.
    Refer to the apex.fp16_utils documentation for more information.

    Example::

        model = torch.nn.Linear(D_in, D_out).cuda().half()
        optimizer = apex.optimizers.FusedAdam(model.parameters())
        # Name the FP16_Optimizer instance to replace the existing optimizer
        # (recommended but not required):
        optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
        ...
        # loss.backward() becomes:
        optimizer.backward(loss)
        ...

    Example with dynamic loss scaling::

        ...
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        # optional arg to control dynamic loss scaling behavior
        # dynamic_loss_args={'scale_window' : 500})
        # Usually, dynamic_loss_args is not necessary.
    """
    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):
        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain the same user API as apex.fp16_utils
        # 2. keep common code here in case we need to add new fused optimizers later
        # Differences from apex.fp16_utils:
        # - assumes all model params are fp16
        # - assumes all params require grad
        # - flattens params by group, not keeping state. TODO: remove state explicitly?
        # - master grads and unflattened master weights never exist. TODO: a way to save out unflattened master weights?
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        # params flattened by group
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before modifying it
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weights to slices of the flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weights, flattened
            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
            # modify the optimizer to use the flat master weight
            self.fp32_groups_flat[i].requires_grad = True  # keep this in case the internal optimizer uses it
            param_group['params'] = [self.fp32_groups_flat[i]]

        # We may eventually fuse dynamic loss scaling; it is not supported for now.
        if dynamic_loss_scale:
            if dynamic_loss_args is not None:
                raise SystemError("Dynamic loss scale args are not supported for now.")
            self.dynamic_loss_scale = True
            self.cur_scale = 2**32
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2
            self.scale_window = 1000
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale
    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grads should never exist.
        # For speed, set model fp16 grads to None by default.
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
        """
        Compute the fp16 grad norm for later clipping (fused with the update).
        Internally accumulated in fp32.
        The NaN/inf check is fused in as well; other gradient reductions could be added here if needed.

        Args:
            fp16_grads_flat (tensor): flattened fp16 gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the current fp16 gradients (viewed as a single vector).
            Returns -1 if the most recently computed fp16 gradients overflowed.
        """
        # TODO: currently uses the pre-1.0 API, and is not the most efficient option
        # (copies to CPU and synchronizes). Only the 2-norm is supported for now.
        norm = float(fused_norm(fp16_grads_flat))
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            return -1
        else:
            return norm
    def step(self, closure=None):
        """
        Closures are not supported.
        """
        # First compute the norm for every group so we know whether any gradient overflowed.
        grads_groups_flat = []
        norm_groups = []
        skip = False
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
            norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
            if norm_groups[i] == -1:  # TODO: early break
                skip = True

        if skip:
            self._update_scale(skip)
            return

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=self.cur_scale,
                            grad_norms=norm_groups)

        # TODO: we probably don't need this; just to be safe.
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        self._update_scale(False)
        return

    def backward(self, loss):
        """
        :attr:`backward` performs the following conceptual steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        scaled_loss = (loss.float()) * self.cur_scale
        scaled_loss.backward()
    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            if skip:
                print("\nGrad overflow on iteration", self.cur_iter)
                print("Using dynamic loss scale of", self.cur_scale)
                self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
                self.last_overflow_iter = self.cur_iter
            else:
                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                    self.cur_scale *= self.scale_factor
        else:
            if skip:
                print("\nGrad overflow on iteration", self.cur_iter)
                print("Using static loss scale of", self.cur_scale)
        self.cur_iter += 1
        return

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
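Putting the wrapper to use, a minimal training-loop sketch along the lines of the docstring above (model, loader, loss_fn, D_in and D_out are placeholders, not part of this change):

model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

for inputs, targets in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs.cuda().half()), targets.cuda())
    optimizer.backward(loss)   # scales the loss, then calls backward()
    optimizer.step()           # skips the update and shrinks the scale on overflow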
import types
import torch
import fused_adam_cuda

class FusedAdam(torch.optim.Optimizer):

    """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.
...@@ -31,14 +32,19 @@ class FusedAdam(torch.optim.Adam):
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params,
                 lr=1e-3, bias_correction = True,
                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
                 weight_decay=0., max_grad_norm=0., amsgrad=False):
        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(FusedAdam, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1

    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
        """Performs a single optimization step.

        Arguments:
...@@ -56,14 +62,47 @@ class FusedAdam(torch.optim.Adam):
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None]*len(self.param_groups)
        # backward compatibility:
        # assuming a list/generator of parameters means a single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0])!=list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            output_params_group = [None]*len(self.param_groups)
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0])!=list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        if grad_norms is None:
            grad_norms = [None]*len(self.param_groups)

        for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):
            if grads_this_group is None:
                grads_this_group = [None]*len(group['params'])
            if output_params_this_group is None:
                output_params_this_group = [None]*len(group['params'])

            # compute the combined scale factor for this group
            combined_scale = scale
            if group['max_grad_norm'] > 0:
                # norm is in fact norm*scale
                clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
                if clip > 1:
                    combined_scale = clip * scale

            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                # note: p.grad should never be set for correct operation of a mixed
                # precision optimizer that sometimes sends None gradients
                if p.grad is None and grad is None:
                    continue
...@@ -86,6 +125,7 @@ class FusedAdam(torch.optim.Adam):
                beta1, beta2 = group['betas']
                state['step'] += 1

                out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
                fused_adam_cuda.adam(p.data,
                                     out_p,
...@@ -96,8 +136,9 @@ class FusedAdam(torch.optim.Adam):
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     combined_scale,
                                     state['step'],
                                     self.eps_mode,
                                     bias_correction,
                                     group['weight_decay'])
        return loss
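To make the combined scale factor above concrete, a small worked example (the numbers are illustrative): because grad_norm is the norm of the still-scaled fp16 gradients, dividing by `scale` first recovers the true norm, and multiplying the clip ratio back into the scale lets the kernel both unscale and clip with a single division.

scale = 128.0                # loss scale used during backward
grad_norm = 1280.0           # fused norm of the scaled grads -> true norm is 10.0
max_grad_norm = 1.0

clip = ((grad_norm / scale) + 1e-6) / max_grad_norm   # ~10: over the limit
combined_scale = clip * scale if clip > 1 else scale  # ~1280
# Inside the kernel, grad / combined_scale == (grad / scale) / clip,
# i.e. the unscaled gradient clipped to global norm max_grad_norm.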