You need to sign in or sign up before continuing.
Unverified Commit 7e7b0a8d authored by Samyam Rajbhandari's avatar Samyam Rajbhandari Committed by GitHub
Browse files

Add files via upload

Lamb CUDA Kernels
parent c04ae78a
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>
// CUDA forward declaration
void fused_lamb_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g,
float lr, float beta1, float beta2, float max_coeff, float min_coeff, float eps, float grad_scale, int step, int mode, int bias_correction, float decay,
at::Tensor & w_l2_i, at::Tensor & u_l2_i, at::Tensor & lamb_coeff_val );
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// C++ interface
at::Tensor lamb(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float max_coeff, float min_coeff, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
//intermediate for weight L2 reduction
//make sure that the threads per block is at least 512 during the kernel launch otherwise the behavious is unexpected
at::Tensor w_l2_i = at::empty({512}, p.options().dtype(p.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : p.type().scalarType()));
//intermediate for update L2 reduction
//make sure that the threads per block is at least 512 during the kernel launch otherwise the behavious is unexpected
at::Tensor u_l2_i = at::empty({512}, p.options().dtype(p.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : p.type().scalarType()));
at::Tensor lamb_coeff_val = at::empty({1}, p.options().dtype(p.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : p.type().scalarType()));
fused_lamb_cuda(p, p_copy, m, v, g, lr, beta1, beta2, max_coeff, min_coeff, eps, grad_scale, step, mode, bias_correction, decay, w_l2_i, u_l2_i, lamb_coeff_val);
return lamb_coeff_val;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include <iostream>
//#include <helper_functions.h>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory
{
// Ensure that we won't compile any un-specialized types
__device__ inline operator T *()
{
extern __device__ void error(void);
error();
return NULL;
}
};
template <>
struct SharedMemory <float>
{
__device__ inline operator float *()
{
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory <double>
{
__device__ inline operator double *()
{
extern __shared__ double s_double[];
return s_double;
}
};
}
#include "type_shim.h"
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
//s_a and s_b are in shared memory
//g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void
reduce_block_in_shared_memory(T *s_a, T *s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
#if (__CUDA_ARCH__ >= 300 )
if ( tid < 32 )
{
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64)
{
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize/2; offset > 0; offset /= 2)
{
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
#endif
// write result for this block to global mem
if (tid == 0){
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b){
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T *s_a = SharedMemory<T>();
T *s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T,blockSize>(s_a, s_b ,g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j+=totThreads) {
T scaled_grad = g[j]/grad_scale;
T pj = p[j];
m[j] = b1*m[j] + (1-b1)*scaled_grad;
v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j]/denom) + (decay*p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T,blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(
const size_t tsize,
T* __restrict__ g_a,
T* __restrict__ g_b)
{
T *s_a = SharedMemory<T>() ;
T *s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize){
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T,blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w !=0 and reg_u !=0){
lamb_coeff = reg_w/reg_u;
if (lamb_coeff > max_coeff){
lamb_coeff = max_coeff;
}
if (lamb_coeff < min_coeff){
lamb_coeff = min_coeff;
}
}
if(blockId == 0 and threadIdInBlock == 0)
{
lamb_coeff_val[0] = lamb_coeff;
//printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j+=totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj/denom) + (decay*pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T) pj;
}
}
void fused_lamb_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor & w_l2_i,
at::Tensor & u_l2_i,
at::Tensor & lamb_coeff)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize+threadsPerBlock-1)/threadsPerBlock;
if (num_blocks > 512) num_blocks=512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.type().scalarType() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock><<<blocks,threadsPerBlock, smemsize, stream>>>(
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock><<<1,threadsPerBlock, smemsize, stream>>>(
num_blocks,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
lamb_cuda_kernel_part3<accscalar_t, scalar_t><<<blocks,threadsPerBlock, smemsize, stream>>>(
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "lamb_cuda_kernel", ([&] {
lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock><<<blocks,threadsPerBlock, smemsize, stream>>>(
p.data<scalar_t>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock><<<1,threadsPerBlock, smemsize, stream>>>(
num_blocks,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
lamb_cuda_kernel_part3<scalar_t, scalar_t><<<blocks,threadsPerBlock, smemsize, stream>>>(
p.data<scalar_t>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
THCudaCheck(cudaGetLastError());
}
//template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a, float* g_b, cg::grid_group &cgg);
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment