"vscode:/vscode.git/clone" did not exist on "17e355f70267e0e01f7c4355a75d46b76f55b5aa"
Commit 2619f1cb authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Resolve merge conflict

parent 91a5a87e
@@ -8,13 +8,10 @@ void fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor
void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_maybe_adam_undo_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
void update_step_and_loss_scaler_cuda(at::Tensor & overflow_flag, at::Tensor & step_and_loss_scaler);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
@@ -84,8 +81,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
m.def("update_step_and_loss_scaler", &update_step_and_loss_scaler_cuda, "Update step and loss scaler.");
}
@@ -14,6 +14,17 @@
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
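// Usage note (a sketch, restating what the code above implies): with ILP == 4
// and T == float, LT is a 16-byte aligned_storage type, so one load_store()
// call moves four floats in a single 128-bit transaction (e.g. LDG.128/STG.128).
// This is only safe when the pointers are 16-byte aligned, which is why callers
// gate the fast path on is_aligned() and on n and chunk_size being multiples of ILP.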
#include "type_shim.h"
typedef enum{
@@ -21,6 +32,359 @@ typedef enum{
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
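// For reference, the per-element update the kernels below implement
// (restated from the code):
//   scaled_grad = g / grad_scale
//   m = b1*m + (1 - b1)*scaled_grad
//   v = b2*v + (1 - b2)*scaled_grad*scaled_grad
//   denom = sqrtf(v + eps)   // ADAM_MODE_0: eps inside the square root
//   denom = sqrtf(v) + eps   // ADAM_MODE_1: eps outside the square root
//   p -= step_size * (m/denom + decay*p)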
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
for (int j = i; j < tsize; j+=totThreads) {
T scaled_grad = g[j]/grad_scale;
m[j] = b1*m[j] + (1-b1)*scaled_grad;
v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*p[j]);
p[j] = p[j] - (step_size*update);
if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
}
}
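// Launch note: fused_adam_cuda below launches this kernel with a 1-D grid of
// ceil(tsize/512) blocks of 512 threads, so totThreads >= tsize and the
// grid-stride loop above runs at most once per thread; the 2-D index math is
// kept for generality.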
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v) &&
is_aligned(g) &&
is_aligned(p_copy))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
GRAD_T tmp_g[ILP];
load_store(incoming_p, p, 0, i_start);
load_store(incoming_m, m, 0, i_start);
load_store(incoming_v, v, 0, i_start);
load_store(tmp_g, g, 0, i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_g[ii] = static_cast<T>(tmp_g[ii]);
T scaled_grad = incoming_g[ii]/grad_scale;
incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(incoming_v[ii] + eps);
else // Mode 1
denom = sqrtf(incoming_v[ii]) + eps;
float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]);
incoming_p[ii] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
}
load_store(p, incoming_p, i_start, 0);
load_store(m, incoming_m, i_start, 0);
load_store(v, incoming_v, i_start, 0);
if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
}
}
else
{
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
}
};
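// How this functor gets its work (a sketch of the multi_tensor_apply contract
// assumed here): the host packs per-tensor base pointers into tl.addresses and
// element counts into tl.sizes, then launches one block per (tensor, chunk)
// pair; block_to_tensor/block_to_chunk recover which pair blockIdx.x owns.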
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
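// Note: bias correction is folded into step_size. With bc1 = 1 - beta1^step and
// bc2 = 1 - beta2^step, the textbook update lr * (m/bc1) / (sqrtf(v/bc2) + eps)
// equals (lr*sqrtf(bc2)/bc1) * m / (sqrtf(v) + eps*sqrtf(bc2)), so scaling lr
// this way matches Adam exactly up to the eps term.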
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's a wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
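// Hypothetical call (tensor names illustrative, not from this file): fp32
// optimizer states with fp16 gradients, emitting an fp16 copy of the updated
// parameters into p_copy:
//   fused_adam_cuda(p, p_copy, m, v, g, lr, 0.9f, 0.999f, 1e-8f,
//                   /*grad_scale=*/65536.f, step, /*mode=*/0,
//                   /*bias_correction=*/1, /*decay=*/0.f);
// Passing an empty tensor as p_copy (p_copy.numel() == 0) skips the fp16 copy.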
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
} else {
if (tl_sz == 5) {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
}
THCudaCheck(cudaGetLastError());
}
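// Expected layout (from the signature comment above): tensor_lists[0..3] are
// the p, m, v and g lists, with an optional tensor_lists[4] of p_copy outputs.
// The outer index selects the role, the inner index the tensor; presumably all
// inner lists must have the same length for the chunking to line up.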
template <typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
@@ -202,44 +566,6 @@ __global__ void maybe_cast_kernel(
}
}
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
for (int j = i; j < tsize; j+=totThreads) {
T scaled_grad = g[j]/grad_scale;
m[j] = b1*m[j] + (1-b1)*scaled_grad;
v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*p[j]);
p[j] = p[j] - (step_size*update);
if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
}
}
template <typename T, typename GRAD_T, typename REDU_T>
__global__ void reversible_adam_cuda_kernel(
T* __restrict__ p,
@@ -402,266 +728,53 @@ __global__ void maybe_adam_undo_cuda_kernel(
}
}
}
}
__global__ void update_step_and_loss_scaler_kernel(
volatile int* overflow_flag,
double* __restrict__ step_and_loss_scaler_vec)
{
// 0 : step
// 1 : iter
// 2 : loss_scale
// 3 : last_overflow_iter
// 4 : scale_factor
// 5 : scale_window
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
double loss_scale = step_and_loss_scaler_vec[2];
double scale_factor = step_and_loss_scaler_vec[4];
int iter = static_cast<int>(step_and_loss_scaler_vec[1]);
int last_overflow_iter = static_cast<int>(step_and_loss_scaler_vec[3]);
if (*overflow_flag == 0) {
// increase step
step_and_loss_scaler_vec[0] += 1.0;
// maybe increase loss scaler
int scale_window = static_cast<int>(step_and_loss_scaler_vec[5]);
if (((iter - last_overflow_iter) % scale_window) == 0) {
step_and_loss_scaler_vec[2] = loss_scale * scale_factor;
}
} else {
step_and_loss_scaler_vec[2] = loss_scale / scale_factor;
step_and_loss_scaler_vec[3] = static_cast<double>(iter);
}
step_and_loss_scaler_vec[1] += 1.0;
}
}
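// Worked example of the policy above: with scale_factor = 2 and
// scale_window = 2000, the loss scale doubles after 2000 consecutive
// overflow-free iterations and halves on any iteration whose overflow_flag is
// set; iter (slot 1) advances on every call, step (slot 0) only when no
// overflow occurred.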
template <int DEPTH, typename FROM_T, typename TO_T>
struct MaybeCastFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* overflow_flag,
TensorListMetadata<DEPTH>& tl)
{
if (overflow_flag && *overflow_flag != 0) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];
p_in += chunk_idx*chunk_size;
TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];
p_out += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
int dim = chunk_size < n ? chunk_size : n;
FROM_T pi[ILP];
TO_T po[ILP];
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = FROM_T(0);
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
p_out[j] = po[ii];
}
}
}
}
};
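// Note: the early return above means the cast runs only when no overflow has
// been flagged; on an overflowed step the whole p_in -> p_out copy is skipped
// for this launch.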
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* overflow_flag,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
int dim = chunk_size < n ? chunk_size : n;
T mi[ILP];
T vi[ILP];
T pi[ILP];
T gi[ILP];
bool overflow = false;
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) {
mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vi[ii] + eps);
else // Mode 1
denom = sqrtf(vi[ii]) + eps;
float update = (mi[ii]/denom) + (decay*pi[ii]);
pi[ii] = pi[ii] - (step_size*update);
} else {
overflow = true;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
}
}
}
if (overflow) {
*overflow_flag = 1;
}
}
};
}
-template <int DEPTH, typename T, typename GRAD_T>
-struct MaybeAdamUndoFunctor
+template <int DEPTH, typename FROM_T, typename TO_T>
+struct MaybeCastFunctor
 {
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* overflow_flag,
-TensorListMetadata<DEPTH>& tl,
-const float b1,
-const float b2,
-const float eps,
-const float grad_scale,
-const float step_size,
-adamMode_t mode,
-const float decay)
+TensorListMetadata<DEPTH>& tl)
 {
-// Skip Adam undo when overflow flag is NOT set
-if (overflow_flag && *overflow_flag == 0) return;
+if (overflow_flag && *overflow_flag != 0) return;
 int tensor_loc = tl.block_to_tensor[blockIdx.x];
 int chunk_idx = tl.block_to_chunk[blockIdx.x];
 int n = tl.sizes[tensor_loc];
-T* p = (T *)tl.addresses[0][tensor_loc];
-p += chunk_idx*chunk_size;
-T* m = (T *)tl.addresses[1][tensor_loc];
-m += chunk_idx*chunk_size;
-T* v = (T *)tl.addresses[2][tensor_loc];
-v += chunk_idx*chunk_size;
-GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
-g += chunk_idx*chunk_size;
+FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];
+p_in += chunk_idx*chunk_size;
+TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];
+p_out += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
 int dim = chunk_size < n ? chunk_size : n;
-T mi[ILP];
-T vi[ILP];
-T pi[ILP];
-T gi[ILP];
+FROM_T pi[ILP];
+TO_T po[ILP];
 for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
 #pragma unroll
 for(int ii = 0; ii < ILP; ii++) {
-mi[ii] = T(0);
-vi[ii] = T(0);
-pi[ii] = T(0);
-gi[ii] = GRAD_T(0);
+pi[ii] = FROM_T(0);
 int j = j_start + threadIdx.x + ii*blockDim.x;
 if (j < dim) {
-pi[ii] = p[j];
-mi[ii] = m[j];
-vi[ii] = v[j];
-gi[ii] = static_cast<T>(g[j]);
+pi[ii] = p_in[j];
 }
 }
 #pragma unroll
 for(int ii = 0; ii < ILP; ii++) {
-T scaled_grad = gi[ii]/grad_scale;
-if (isfinite(scaled_grad)) {
-float denom;
-if (mode == ADAM_MODE_0)
-denom = sqrtf(vi[ii] + eps);
-else // Mode 1
-denom = sqrtf(vi[ii]) + eps;
-pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);
-mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;
-vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
-// Make sure round off errors don't create (small) negative value.
-// This can happen if we have to revert the very first step.
-vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
-}
+convert(pi[ii], po[ii]);
 }
 #pragma unroll
 for(int ii = 0; ii < ILP; ii++) {
 int j = j_start + threadIdx.x + ii*blockDim.x;
 if (j < dim) {
-m[j] = mi[ii];
-v[j] = vi[ii];
-p[j] = pi[ii];
+p_out[j] = po[ii];
 }
 }
 }
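// Derivation of the undo step removed above (restating the algebra): the
// forward update is
//   p_new = p_old - step_size*(m/denom + decay*p_old)
//         = p_old*(1 - step_size*decay) - step_size*(m/denom),
// so p_old = (p_new + step_size*(m/denom)) / (1 - step_size*decay); m and v
// then invert linearly, using the same denom because it is computed from the
// post-update v before v is reverted.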
@@ -696,86 +809,6 @@ void fused_strided_check_finite(
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's a wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
void fused_reversible_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
@@ -923,18 +956,6 @@ void maybe_cast_cuda_mt(
THCudaCheck(cudaGetLastError());
}
void update_step_and_loss_scaler_cuda(
at::Tensor & overflow_flag,
at::Tensor & step_and_loss_scaler)
{
AT_ASSERTM(step_and_loss_scaler.numel() == 6, "step_and_loss_scaler must have 6 elements");
AT_ASSERTM(step_and_loss_scaler.scalar_type() == at::ScalarType::Double, "expected step_and_loss_scaler to be a double tensor");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
update_step_and_loss_scaler_kernel<<<1,1,0,stream>>>(
overflow_flag.DATA_PTR<int>(),
step_and_loss_scaler.DATA_PTR<double>());
}
void fused_maybe_adam_undo_cuda(
at::Tensor & overflow_flag,
at::Tensor & p,
@@ -1013,178 +1034,3 @@ void fused_maybe_adam_undo_cuda(
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor overflow_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
} else {
if (tl_sz == 5) {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
}
THCudaCheck(cudaGetLastError());
}
void fused_maybe_adam_undo_cuda_mt(
int chunk_size,
at::Tensor overflow_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4, "expected tensor list of size 4");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
MaybeAdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
MaybeAdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}