"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "e57f5d0e69359dbff9c2e23b49275820970fc55d"
Commit c8f9cceb authored by Deyu Fu's avatar Deyu Fu
Browse files

add fused lamb, put lamb kernels into one file

parent aee5aff4
from .fused_sgd import FusedSGD from .fused_sgd import FusedSGD
from .fused_adam import FusedAdam from .fused_adam import FusedAdam
from .fused_novograd import FusedNovoGrad from .fused_novograd import FusedNovoGrad
from .fused_lamb import FusedLAMB
from .fp16_optimizer import FP16_Optimizer from .fp16_optimizer import FP16_Optimizer
import torch
from apex.multi_tensor_apply import multi_tensor_applier
class FusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
NOT SUPPORTED now! (default: False)
reg_inside_moment (bool, optional): whether do regularization (norm and L2)
in momentum calculation. True for include, False for not include and
only do it on update term. (default: False)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
"""
def __init__(self, params, lr=1e-3, bias_correction=True,
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, reg_inside_moment=False,
grad_averaging=True, set_grad_none=True,
max_grad_norm=1.0):
if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import amp_C
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
self.moment_mode = 0 if reg_inside_moment else 1
self.set_grad_none = set_grad_none
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
else:
super(FusedLAMB, self).zero_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
# assume same step across group now to simplify things
# per parameter step can be easily support by making it tensor, or pass list into kernel
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1
# create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []
for p in group['params']:
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if p.dtype == torch.float16:
g_16.append(p.grad.data)
p_16.append(p.data)
m_16.append(state['exp_avg'])
v_16.append(state['exp_avg_sq'])
elif p.dtype == torch.float32:
g_32.append(p.grad.data)
p_32.append(p.data)
m_32.append(state['exp_avg'])
v_32.append(state['exp_avg_sq'])
else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.')
if(len(g_16) > 0):
multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf,
[g_16, p_16, m_16, v_16],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
bias_correction,
group['weight_decay'],
grad_averaging,
self.moment_mode,
group['max_grad_norm'])
if(len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf,
[g_32, p_32, m_32, v_32],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
bias_correction,
group['weight_decay'],
grad_averaging,
self.moment_mode,
group['max_grad_norm'])
return loss
...@@ -33,44 +33,39 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda( ...@@ -33,44 +33,39 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python); at::optional<bool> per_tensor_python);
void multi_tensor_lamb_stage1_cuda( void multi_tensor_adam_cuda(
int chunk_size, int chunk_size,
at::Tensor noop_flag, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_decay, const float lr,
const int step,
const float beta1, const float beta1,
const float beta2, const float beta2,
const float epsilon, const float epsilon,
const float global_grad_norm, const int step,
const float max_global_grad_norm); const int eps_mode,
const int bias_correction,
void multi_tensor_lamb_stage2_cuda( const float weight_decay);
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float step_size);
void multi_tensor_adam_cuda( void multi_tensor_novograd_cuda(
int chunk_size, int chunk_size,
at::Tensor noop_flag, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor grad_norms,
const float lr, const float lr,
const float beta1, const float beta1,
const float beta2, const float beta2,
const float epsilon, const float epsilon,
const int step, const int step,
const int eps_mode,
const int bias_correction, const int bias_correction,
const float weight_decay); const float weight_decay,
const int grad_averaging,
const int moment_mode,
const int norm_type);
void multi_tensor_novograd_cuda( void multi_tensor_lamb_cuda(
int chunk_size, int chunk_size,
at::Tensor noop_flag, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor grad_norms,
const float lr, const float lr,
const float beta1, const float beta1,
const float beta2, const float beta2,
...@@ -80,7 +75,7 @@ void multi_tensor_novograd_cuda( ...@@ -80,7 +75,7 @@ void multi_tensor_novograd_cuda(
const float weight_decay, const float weight_decay,
const int grad_averaging, const int grad_averaging,
const int moment_mode, const int moment_mode,
const int norm_type); const float max_grad_norm);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda, m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
...@@ -91,12 +86,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -91,12 +86,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"out = a*x + b*y for a list of contiguous tensors"); "out = a*x + b*y for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors"); "Computes L2 norm for a list of contiguous tensors");
m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
"Computes update part of LAMB optimizer");
m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
"Completes application of gradient to parameters for LAMB optimizer");
m.def("multi_tensor_adam", &multi_tensor_adam_cuda, m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"); "Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda, m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"); "Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
"Computes and apply update for LAMB optimizer");
} }
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum{
MOMENT_MODE_0 =0, // Momentum with denom/decay, optional grad averaging after
MOMENT_MODE_1 =1 // Momentum without denom/decay
} momentMode_t;
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
using MATH_T = float;
template<typename T>
struct LAMBStage1Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta3,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
momentMode_t m_mode,
const float decay,
float* global_grad_norm,
float max_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = g[i];
// special ?optimization? for lamb stage 1
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = p[i];
}
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (m_mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = std::sqrt(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = std::sqrt(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
g[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T>
struct LAMBStage2Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
T* update = (T*)tl.addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_p[ILP];
MATH_T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
}
}
}
}
};
void multi_tensor_lamb_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int moment_mode,
const float max_grad_norm)
{
using namespace at;
// Master weight and 32bit momentum(potentially changing) is not handled by this
// So we assume every tensor are all in the same type
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
// Handle grad averaging mode
float beta3 = 1.0f;
if (grad_averaging == 1) beta3 = 1 - beta1;
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
// Compute global grad norm
auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
// Compute per tensor param norm
auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
// We now in-place modify grad to store update before compute its norm
// Generally this is not a issue since people modify grad in step() method all the time
// We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LAMBStage1Functor<scalar_t_0>(),
beta1,
beta2,
beta3, // 1-beta1 or 1 depends on averaging mode
bias_correction1,
bias_correction2,
epsilon,
(momentMode_t) moment_mode,
weight_decay,
std::get<0>(grad_norm_tuple).data<float>(),
max_grad_norm); )
// Compute update norms
auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
grad_param_list,
LAMBStage2Functor<scalar_t_0>(),
std::get<1>(param_norm_tuple).data<float>(),
std::get<1>(update_norm_tuple).data<float>(),
lr); )
AT_CUDA_CHECK(cudaGetLastError());
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
// Step 1 computes the 'update' value of regular Adam optimizer.
template<typename GRAD_T, typename T, typename UPD_T>
struct LAMBStage1Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<5>& tl,
const float* per_tensor_decay,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float clipped_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float decay = per_tensor_decay[tensor_num];
GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
update += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
GRAD_T r_g[ILP];
T r_p[ILP];
T r_m[ILP];
T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = GRAD_T(0);
r_p[ii] = T(0);
r_m[ii] = T(0);
r_v[ii] = T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + (1-beta1) * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
T next_m_unbiased = r_m[ii] / beta1_correction;
T next_v_unbiased = r_v[ii] / beta2_correction;
T denom = std::sqrt(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
update[i] = (UPD_T)r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_lamb_stage1_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_decay,
const int step,
const float beta1,
const float beta2,
const float epsilon,
const float global_grad_norm,
const float max_global_grad_norm)
{
using namespace at;
float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
float next_step = float(step+1);
float beta1_correction = 1.0f - std::pow(beta1, next_step);
float beta2_correction = 1.0f - std::pow(beta2, next_step);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
per_tensor_decay.data<float>(),
beta1,
beta2,
beta1_correction,
beta2_correction,
epsilon,
clipped_global_grad_norm); )))
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename UPD_T>
struct LAMBStage2Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
T* p = (T*)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
update += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
T r_p[ILP];
UPD_T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio*(T)r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
}
}
}
}
};
void multi_tensor_lamb_stage2_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
per_tensor_param_norm.data<float>(),
per_tensor_update_norm.data<float>(),
learning_rate); ))
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
...@@ -89,10 +89,9 @@ if "--cuda_ext" in sys.argv: ...@@ -89,10 +89,9 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_scale_kernel.cu', 'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu', 'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu', 'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu',
'csrc/multi_tensor_adam.cu', 'csrc/multi_tensor_adam.cu',
'csrc/multi_tensor_novograd.cu'], 'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'],
extra_compile_args={'cxx': ['-O3'], extra_compile_args={'cxx': ['-O3'],
'nvcc':['-lineinfo', 'nvcc':['-lineinfo',
'-O3', '-O3',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment