Commit be42aad5 authored by Deyu Fu

WIP: improve fused adam

parent b436213e
from .fused_adam import FusedAdam
from .fp16_optimizer import FP16_Optimizer
#include <torch/extension.h>
// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// C++ interface
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p)
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
@@ -20,7 +20,7 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -28,7 +28,8 @@ __global__ void adam_cuda_kernel(
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay) {
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
@@ -46,7 +47,8 @@ __global__ void adam_cuda_kernel(
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
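// weight decay (if any) is folded into the fused update below: decay*p[j] is
// added to the Adam step before it is scaled by step_size; decay == 0
// reduces this to the plain Adam update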
float update = (m[j]/denom) + (decay*p[j]);
p[j] = p[j] - (step_size*update);
if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
}
}
@@ -63,7 +65,9 @@ void fused_adam_cuda(
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Get tensor size
int tsize = p.numel();
@@ -72,9 +76,15 @@ void fused_adam_cuda(
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
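// bias_correction == 1 uses the usual bias-corrected Adam step size;
// bias_correction == 0 uses lr directly, e.g. when the Python side already
// applies an lr schedule and skips the correction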
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.type().scalarType() == at::ScalarType::Half) {
@@ -95,7 +105,8 @@ void fused_adam_cuda(
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
}));
} else {
AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
@@ -111,7 +122,8 @@ void fused_adam_cuda(
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
}));
}
THCudaCheck(cudaGetLastError());
...
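For orientation, here is a rough pure-PyTorch sketch of the per-element update the kernel above performs. It is a hand-written approximation (the moment updates sit in an elided hunk, so they are reconstructed from the standard Adam recurrences); the function name and signature are illustrative, not part of the extension API:

    import math
    import torch

    def fused_adam_reference(p, m, v, g, lr, beta1, beta2, eps, grad_scale,
                             step, mode=1, bias_correction=1, decay=0.0):
        # undo loss scaling on the raw gradient
        scaled_grad = g.float() / grad_scale
        # exponential moving averages of the gradient and its square
        m.mul_(beta1).add_(scaled_grad, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(scaled_grad, scaled_grad, value=1 - beta2)
        # mode 0 keeps eps inside the sqrt, mode 1 (default Adam) keeps it outside
        denom = (v + eps).sqrt() if mode == 0 else v.sqrt() + eps
        if bias_correction:
            step_size = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
        else:
            step_size = lr
        # optional weight decay term added to the update, as in the kernel
        update = m / denom + decay * p
        p.add_(update, alpha=-step_size)
        return p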
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` is a cut-down version of apex.fp16_utils.FP16_Optimizer.
It is designed to be used in the same way, but supports only the fused optimizers in apex.
Refer to the apex.fp16_utils documentation for more information.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = apex.optimizers.FusedAdam(model.parameters())
# Name the FP16_Optimizer instance to replace the existing optimizer
# (recommended but not required):
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
# loss.backward() becomes:
optimizer.backward(loss)
...
Example with dynamic loss scaling::
...
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
# optional arg to control dynamic loss scaling behavior
# dynamic_loss_args={'scale_window' : 500})
# Usually, dynamic_loss_args is not necessary.
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True):
# The fused optimizer does all the work. We need this layer for two reasons:
# 1. maintain same user API from apex.fp16_utils
# 2. keep common stuff here in case we need to add new fused optimizer later
# differences from apex.fp16_utils:
# - assume all model params in fp16
# - assume all params require grad
# - flattened by groups, not keeping state. TODO: remove state explicitly?
# - master grads and unflattened master weights never exist. TODO: a way to save out unflattened master weights?
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
self.fp32_groups_flat = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to the list before modifying it
self.fp16_groups.append(param_group['params'])
# init fp16 weight buffer, flattened
self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
for p,q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
# init master weight, flattened
self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
# modify optimizer to have flat master weight
self.fp32_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.fp32_groups_flat[i]]
# we may find a way of fusing dynamic loss scaling later. Not supported for now.
if dynamic_loss_scale:
if dynamic_loss_args is not None:
raise SystemError("Do not support dynamic loss scale args for now.")
self.dynamic_loss_scale = True
self.cur_scale = 2**32
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = 2
self.scale_window = 1000
else:
self.dynamic_loss_scale = False
self.cur_iter = 0
self.cur_scale = static_loss_scale
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
"""
Compute the fp16 grad norm for later clipping (fused with the update).
Accumulation is done internally in fp32.
The NaN check is also fused in. Other reductions over the grads may be added here if needed.
Args:
fp16_grads_flat (tensor): fp16 grad flattened
norm_type (float or int): type of the p-norm to use. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the current fp16 gradients (viewed as a single vector).
Returns -1 if the most recently computed fp16 gradients overflowed
"""
# TODO: currently uses the pre-1.0 API and is not the most efficient (copies to CPU and syncs)
norm = float(torch.norm(fp16_grads_flat, p=norm_type))
if norm == float('inf') or norm == -float('inf') or norm != norm:
return -1
else:
return norm
def step(self, closure=None):
"""
Not supporting closure.
"""
# First compute the norm for all groups so we know if there is an overflow
grads_groups_flat = []
norm_groups = []
skip = False
for i, group in enumerate(self.fp16_groups):
grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
if norm_groups[i] == -1: #TODO: early break
skip = True
if skip:
self._update_scale(skip)
return
# norm is in fact norm*cur_scale
self.optimizer.step(grads_group=[[g] for g in grads_groups_flat],
output_params_group=[[p] for p in self.fp16_groups_flat],
scale=self.cur_scale,
grad_norms=norm_groups)
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
for p,q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
self._update_scale(False)
return
def backward(self, loss):
"""
:attr:`backward` performs the following conceptual steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
scaled_loss = (loss.float()) * self.cur_scale
scaled_loss.backward()
def _update_scale(self, skip):
if self.dynamic_loss_scale:
if skip:
print("grad overflow on iteration", self.cur_iter)
print("Using dynamic loss scale of", self.cur_scale)
self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
self.last_overflow_iter = self.cur_iter
else:
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
self.cur_scale *= self.scale_factor
else:
if skip:
print("Grad overflow on iteration", self.cur_iter)
print("Using static loss scale of", self.cur_scale)
self.cur_iter +=1
return
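A minimal end-to-end sketch of how this wrapper is meant to be driven, mirroring the class docstring above. Layer sizes, batch size and data are placeholders, and it assumes the extension is built and both classes are importable from apex.optimizers:

    import torch
    from apex.optimizers import FusedAdam, FP16_Optimizer

    D_in, D_out = 512, 256                      # illustrative sizes
    model = torch.nn.Linear(D_in, D_out).cuda().half()

    optimizer = FusedAdam(model.parameters(), lr=1e-3,
                          weight_decay=0.01, max_grad_norm=1.0)
    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    for _ in range(10):
        x = torch.randn(32, D_in, device='cuda', dtype=torch.half)
        y = torch.randn(32, D_out, device='cuda', dtype=torch.half)
        loss = torch.nn.functional.mse_loss(model(x).float(), y.float())

        optimizer.zero_grad()
        optimizer.backward(loss)   # replaces loss.backward(): scales the loss first
        optimizer.step()           # unscale, clip by max_grad_norm, fused Adam update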
import math
import torch
import fused_adam_cuda
def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x/warmup
return 0.5 * (1.0 + math.cos(math.pi * x))
def warmup_constant(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0 - x
SCHEDULES = {
'warmup_cosine':warmup_cosine,
'warmup_constant':warmup_constant,
'warmup_linear':warmup_linear,
}
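# For intuition: the schedules map training progress x = step / t_total to an
# lr multiplier. With warmup=0.1 (an illustrative value):
#   x = 0.05 -> warmup_constant: 0.50, warmup_linear: 0.50   (still in warmup: x/warmup)
#   x = 0.50 -> warmup_constant: 1.00, warmup_linear: 0.50
#   x = 1.00 -> warmup_constant: 1.00, warmup_linear: 0.00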
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``. ``python setup.py install --cuda_ext --cpp_ext``.
...@@ -31,14 +53,19 @@ class FusedAdam(torch.optim.Adam): ...@@ -31,14 +53,19 @@ class FusedAdam(torch.optim.Adam):
https://openreview.net/forum?id=ryQu7f-RZ https://openreview.net/forum?id=ryQu7f-RZ
""" """
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, def __init__(self, params,
weight_decay=0, amsgrad=False, eps_inside_sqrt = False): lr=1e-3, warmup=-1, t_total=-1, schedule='warmup_linear',
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False):
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(FusedAdam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
def step(self, closure=None, grads_group=None, output_params_group=None, scale=1., grad_norms=None):
"""Performs a single optimization step.
Arguments:
@@ -56,14 +83,29 @@ class FusedAdam(torch.optim.Adam):
loss = None
if closure is not None:
loss = closure()
if grads_group is None:
grads_group = [None]*len(self.param_groups)
if output_params_group is None:
output_params_group = [None]*len(self.param_groups)
if grad_norms is None:
grad_norms = [None]*len(self.param_groups)
for group, grads, output_params, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):
if grads is None:
grads = [None]*len(group['params'])
if output_params is None:
output_params = [None]*len(group['params'])
# compute combined scale factor for this group
combined_scale = scale
if group['max_grad_norm'] > 0:
# norm is in fact norm*scale
clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
if clip > 1:
combined_scale = clip * scale
for p, grad, output_param in zip(group['params'], grads, output_params):
#note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
if p.grad is None and grad is None:
continue
@@ -85,7 +127,16 @@ class FusedAdam(torch.optim.Adam):
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
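# when an lr schedule is active (t_total != -1), the scheduled lr is computed
# here and the kernel's bias correction is disabled (bias_correction = 0);
# otherwise the plain lr is used with bias correction enabled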
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
bias_correction = 0
else:
lr_scheduled = group['lr']
bias_correction = 1
state['step'] += 1
out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
fused_adam_cuda.adam(p.data,
out_p,
@@ -96,8 +147,9 @@ class FusedAdam(torch.optim.Adam):
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
return loss
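A small worked example of the combined_scale computed above, with made-up numbers, to show how loss unscaling and gradient clipping are folded into a single division inside the kernel:

    scale = 128.0          # current loss scale
    grad_norm = 640.0      # norm of the scaled fp16 gradients (true norm is 5.0)
    max_grad_norm = 1.0

    clip = ((grad_norm / scale) + 1e-6) / max_grad_norm   # ~5.0 > 1, so clip
    combined_scale = clip * scale if clip > 1 else scale  # ~640.0

    # dividing each gradient by combined_scale both undoes the loss scaling and
    # rescales the step so the effective gradient norm is about max_grad_norm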