Unverified commit 55716d85 authored by Kexin Yu, committed by GitHub

Merge pull request #761 from kexinyu/master

add additional loop for lists of params in FP16_Optimizer's load_state_dict 
parents 11faaca7 04de0f7a
#include <torch/extension.h>
void multi_tensor_lamb_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
const float global_grad_norm,
const float max_grad_norm);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum{
  MOMENT_MODE_0 = 0, // L2 regularization mode
  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
} adamMode_t;
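// MOMENT_MODE_0 folds the decay term into the gradient before the moment
// updates (classic L2 regularization): scaled_grad += decay * p.
// MOMENT_MODE_1 adds the decay term to the Adam-style update after the
// moment updates (decoupled weight decay, as in AdamW): update = m_hat / denom + decay * p.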
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);
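// Overview of the fused LAMB update implemented below:
//   Stage 1 computes the Adam-style update direction for every parameter
//   (using the globally clipped gradient) and writes it back into the grad
//   tensors in place.
//   multi_tensor_l2norm_cuda then provides per-tensor norms of the params
//   and of those updates.
//   Stage 2 rescales each update by the LAMB trust ratio
//   lr * ||param|| / ||update|| (for params with non-zero weight decay)
//   and applies it to the parameters.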
using MATH_T = float;
template<typename T>
struct LAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta3,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    adamMode_t mode,
    const float decay,
    const float global_grad_norm,
    const float max_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];
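    // Global gradient clipping: when the global grad norm exceeds the cap,
    // every gradient is divided by global_grad_norm / max_global_grad_norm,
    // which scales the overall norm down to max_global_grad_norm.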
    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;
    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;
    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;
    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;
    n -= chunk_idx*chunk_size;
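    // Each block handles one chunk of one tensor: the pointers above have been
    // advanced to this chunk, and n reduced so the loop below stays in bounds.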
    // see note in multi_tensor_scale_kernel.cu
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = g[i];
          // Special-case "optimization" for LAMB stage 1: when decay == 0 the
          // parameter value has no effect on this stage's update, so skip the load.
          if (decay == 0) {
            r_p[ii] = MATH_T(0);
          }
          else {
            r_p[ii] = p[i];
          }
          r_m[ii] = m[i];
          r_v[ii] = v[i];
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
        }
      }
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        if (mode == MOMENT_MODE_0) {
          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
          // L2 on scaled grad
          scaled_grad = scaled_grad + decay*r_p[ii];
          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          r_p[ii] = next_m_unbiased / denom;
        }
        else {
          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
        }
      }
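      // Write results back: the new moments go to m/v, and the update
      // direction (held in r_p) overwrites g so that the update-norm pass
      // and stage 2 can read it.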
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          g[i] = r_p[ii];
          m[i] = r_m[ii];
          v[i] = r_v[ii];
        }
      }
    }
  }
};
// Step 2 reads in the 'update' value and the per-tensor param_norm and update_norm.
// It computes the new parameter value.
template<typename T>
struct LAMBStage2Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
    const float learning_rate,
    const float decay)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
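    // tensor_num is the global index of this tensor across launches; it is
    // used to look up the per-tensor param/update norms below.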
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];
    MATH_T ratio = learning_rate;
    // apply adaptive learning rate to parameters with non-zero weight decay
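    // LAMB trust ratio: scale the step by ||param|| / ||update|| when both norms are non-zero.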
    if (decay != 0.0)
    {
      float param_norm = per_tensor_param_norm[tensor_num];
      float update_norm = per_tensor_update_norm[tensor_num];
      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
    }
    T* update = (T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;
    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;
    n -= chunk_idx*chunk_size;
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_p[ILP];
      MATH_T r_update[ILP];
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_p[ii] = p[i];
          r_update[ii] = update[i];
        }
      }
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
      }
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
        }
      }
    }
  }
};
void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  const float global_grad_norm,
  const float max_grad_norm)
{
  using namespace at;
  // Master weights and 32-bit momentum (which may change later) are not handled here,
  // so we assume all tensors in a call are of the same type.
  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }
  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;
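  // tensor_lists is expected to be laid out as [grads, params, exp_avg, exp_avg_sq],
  // matching the [g, p, m, v] lists built on the Python side.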
  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
  // Compute per-tensor param norm
  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
  // We now modify grad in place to store the update before computing its norm.
  // Generally this is not an issue, since grads are routinely modified inside step() methods anyway.
  // We could also pass a list of empty tensors to avoid this, but that would cost extra memory and CPU code.
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      LAMBStage1Functor<scalar_t_0>(),
      beta1,
      beta2,
      beta3, // 1-beta1 or 1, depending on the averaging mode
      bias_correction1,
      bias_correction2,
      epsilon,
      (adamMode_t) mode,
      weight_decay,
      global_grad_norm,
      max_grad_norm); )
  // Compute update norms
  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
    multi_tensor_apply<2>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      grad_param_list,
      LAMBStage2Functor<scalar_t_0>(),
      std::get<1>(param_norm_tuple).DATA_PTR<float>(),
      std::get<1>(update_norm_tuple).DATA_PTR<float>(),
      lr,
      weight_decay); )
  AT_CUDA_CHECK(cudaGetLastError());
}
from .fp16_optimizer import FP16_Optimizer
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
@@ -239,4 +239,5 @@ class FP16_Optimizer(object):
         # constructed in the same way as the one whose state_dict we are loading, the same master params
         # are guaranteed to exist, so we can just copy_() from the saved master params.
         for current, saved in zip(self.fp32_groups, state_dict['fp32_groups']):
-            current.data.copy_(saved.data)
+            for _current, _saved in zip(current, saved):
+                _current.data.copy_(_saved.data)
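# A minimal illustrative sketch, assuming self.fp32_groups is a list of
# per-group lists of fp32 master parameter tensors (hypothetical values):
#
#     self.fp32_groups          == [[p0, p1], [p2]]
#     state_dict['fp32_groups'] == [[s0, s1], [s2]]
#
# copy_() operates on individual tensors, not on lists, which is why the
# restore needs the additional inner loop over the params in each group.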
import torch
import importlib
import math
from apex.multi_tensor_apply import multi_tensor_applier
class FusedLAMB(torch.optim.Optimizer):
    """Implements the LAMB algorithm.

    Currently GPU-only. Requires Apex to be installed via
    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_lamb" ./``.

    This version of fused LAMB implements 2 fusions:

      * Fusion of the LAMB update's elementwise operations
      * A multi-tensor apply launch that batches the elementwise updates applied to
        all the model's parameters into one or a few kernel launches.

    :class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::

        opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)
        ...
        opt.step()

    :class:`apex.contrib.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp,
    you may choose any ``opt_level``::

        opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)
        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
        ...
        opt.step()

    In general, ``opt_level="O1"`` is recommended.

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of the gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_.
            NOT SUPPORTED for now! (default: False)
        adam_w_mode (boolean, optional): whether to apply decoupled weight decay
            (also known as AdamW) instead of L2 regularization. (default: True)
        grad_averaging (bool, optional): whether to apply (1-beta1) to the grad when
            calculating running averages of the gradient. (default: True)
        set_grad_none (bool, optional): whether to set grad to None when the zero_grad()
            method is called. (default: True)
        max_grad_norm (float, optional): value used to clip the global grad norm
            (default: 1.0)

    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                 amsgrad=False, adam_w_mode=True,
                 grad_averaging=True, set_grad_none=True,
                 max_grad_norm=1.0):
        if amsgrad:
            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)
        super(FusedLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
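            # fused_lamb_cuda is the extension declared in setup.py under the
            # --deprecated_fused_lamb flag; it exposes multi_tensor_lamb_cuda as `lamb`.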
            fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
            self.multi_tensor_lamb = fused_lamb_cuda.lamb
        else:
            raise RuntimeError('apex.contrib.optimizers.FusedLAMB requires cuda extensions')

        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none
    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(FusedLAMB, self).zero_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # create separate grad lists for fp32 and fp16 params
        g_all_32, g_all_16 = [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

        g_norm_32, g_norm_16 = 0.0, 0.0
        # compute the grad norm for the two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0].item()
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0].item()

        # blend the two grad norms to get the global grad norm
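        # the fp16 and fp32 grad lists are disjoint, so the combined L2 norm
        # is sqrt(g_norm_32**2 + g_norm_16**2)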
        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
        max_grad_norm = self.defaults['max_grad_norm']

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # assume the same step across the group for now to simplify things
            # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16 = [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

            if len(g_16) > 0:
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.adam_w_mode,
                                     global_grad_norm,
                                     max_grad_norm)
            if len(g_32) > 0:
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.adam_w_mode,
                                     global_grad_norm,
                                     max_grad_norm)

        return loss
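
# Minimal usage sketch, assuming apex was built with the --deprecated_fused_lamb
# extension and that `model`, `loss_fn`, and `loader` are user-defined,
# CUDA-resident objects:
#
#     from apex.contrib.optimizers import FusedLAMB
#
#     model = model.cuda()
#     opt = FusedLAMB(model.parameters(), lr=1e-3, weight_decay=0.01)
#
#     for inputs, targets in loader:
#         opt.zero_grad()                     # sets .grad to None by default
#         loss = loss_fn(model(inputs), targets)
#         loss.backward()
#         opt.step()                          # fused multi-tensor LAMB update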
@@ -198,6 +198,26 @@ if "--deprecated_fused_adam" in sys.argv:
                                          'nvcc':['-O3',
                                                  '--use_fast_math'] + version_dependent_macros}))
if "--deprecated_fused_lamb" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
    sys.argv.remove("--deprecated_fused_lamb")
    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension
    if torch.utils.cpp_extension.CUDA_HOME is None:
        raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        ext_modules.append(
            CUDAExtension(name='fused_lamb_cuda',
                          sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
                                   'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
                                   'csrc/multi_tensor_l2norm_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
                                              'nvcc':['-O3',
                                                      '--use_fast_math'] + version_dependent_macros}))
if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_multihead_attn")