Commit 4a01ff26 authored by Thor Johnsen

Partial move towards syncfree optimizer

parent 2622d7f1
 #include <torch/extension.h>

 // CUDA forward declaration
-void fused_strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first);
+void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);
 void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);
-void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
+void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_maybe_adam_undo_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
+void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
+void update_step_and_loss_scaler_cuda(at::Tensor & overflow_flag, at::Tensor & step_and_loss_scaler);

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -18,13 +20,13 @@ void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::
 // C++ interface

 void strided_check_finite(
-        at::Tensor& noop,
+        at::Tensor& overflow_flag,
         at::Tensor& p_copy,
         int stride,
         int clear_overflow_first
     ) {
     CHECK_INPUT(p_copy);
-    fused_strided_check_finite(noop, p_copy, stride, clear_overflow_first);
+    fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);
 }

 void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
     CHECK_INPUT(p);
@@ -40,7 +42,7 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
     fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

-void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
+void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
     CHECK_INPUT(p);
     CHECK_INPUT(m);
     CHECK_INPUT(v);
@@ -50,23 +52,24 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f
     AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
     AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
-    fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
+    fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

-void unpack_e5m2(at::Tensor & p_in, at::Tensor & p_out) {
+void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {
     CHECK_INPUT(p_in);
     CHECK_INPUT(p_out);
     int64_t num_elem = p_in.numel();
     AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
-    unpack_e5m2_cuda(p_in, p_out);
+    maybe_cast_cuda(overflow_flag, p_in, p_out);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
     m.def("adam", &adam, "Adam optimized CUDA implementation.");
-    m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
     m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
-    m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
-    m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
-    m.def("unpack_e5m2_mt", &unpack_e5m2_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+    m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
+    m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
+    m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
+    m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+    m.def("update_step_and_loss_scaler", &update_step_and_loss_scaler_cuda, "Update step and loss scaler.");
 }
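For orientation, here is a minimal usage sketch of the new bindings. It is not part of this commit; the module name fused_adam_cuda matches the Python changes further down, but the tensor shapes, hyperparameter values, and the assumption that an earlier device-side check has already set overflow_buf are illustrative only.

import torch
import fused_adam_cuda  # extension module built from the file above

# Device-side overflow flag plus the packed scaler state used by
# update_step_and_loss_scaler: [step, iter, loss_scale, last_overflow_iter,
# scale_factor, scale_window].
overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
scaler_state = torch.tensor([0., 0., 2.**15, -1., 2., 1000.],
                            dtype=torch.float64, device='cuda')

# Placeholder fp32 optimizer state for one flat buffer.
p = torch.zeros(1024, device='cuda')
m, v, g = torch.zeros_like(p), torch.zeros_like(p), torch.zeros_like(p)
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
grad_scale, step, mode, bias_correction, decay = 65536.0, 1, 0, 1, 0.0

# ... a device-side finite check (e.g. strided_check_finite) sets overflow_buf ...

# Undo the speculative Adam step only if an overflow was flagged, then let the
# GPU advance the step/iteration counters and adjust the loss scale; the host
# never calls .item() on overflow_buf, which is the point of "syncfree".
fused_adam_cuda.maybe_adam_undo(overflow_buf, p, m, v, g, lr, beta1, beta2, eps,
                                grad_scale, step, mode, bias_correction, decay)
fused_adam_cuda.update_step_and_loss_scaler(overflow_buf, scaler_state)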
@@ -157,6 +157,51 @@ __global__ void strided_check_finite_cuda_kernel(
     }
 }

+template <typename FROM_T, typename TO_T>
+__global__ void maybe_cast_kernel(
+        volatile int* overflow_flag,
+        const FROM_T* p_in,
+        TO_T* p_out,
+        const size_t tsize)
+{
+    if (overflow_flag && *overflow_flag != 0) return;
+
+    //Assuming 2D grids and 2D blocks
+    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
+    const int threadsPerBlock = blockDim.x * blockDim.y;
+    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
+    const int i = (blockId * threadsPerBlock + threadIdInBlock);
+    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
+
+    FROM_T pi[ILP];
+    TO_T po[ILP];
+
+    for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
+#pragma unroll
+        for(int ii = 0; ii < ILP; ii++) {
+            pi[ii] = 0;
+            int j = j_start + i + totThreads*ii;
+            if (j < tsize) {
+                pi[ii] = p_in[j];
+            }
+        }
+#pragma unroll
+        for(int ii = 0; ii < ILP; ii++) {
+            convert(pi[ii], po[ii]);
+        }
+#pragma unroll
+        for(int ii = 0; ii < ILP; ii++) {
+            int j = j_start + i + totThreads*ii;
+            if (j < tsize) {
+                p_out[j] = po[ii];
+            }
+        }
+    }
+}
 template <typename T, typename GRAD_T, typename REDU_T>
 __global__ void adam_cuda_kernel(
         T* __restrict__ p,
@@ -243,58 +288,9 @@ __global__ void adam_cuda_kernel(
     }
 }

-template <typename GRAD_T>
-__global__ void unpack_e5m2_kernel(
-        const uint8_t* p_in,
-        GRAD_T* p_out,
-        const size_t tsize)
-{
-    //Assuming 2D grids and 2D blocks
-    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
-    const int threadsPerBlock = blockDim.x * blockDim.y;
-    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
-    const int i = (blockId * threadsPerBlock + threadIdInBlock);
-    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
-
-    uint8_t pi[ILP];
-    GRAD_T po[ILP];
-
-    bool overflow = false;
-    for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
-#pragma unroll
-        for(int ii = 0; ii < ILP; ii++) {
-            pi[ii] = 0;
-            int j = j_start + i + totThreads*ii;
-            if (j < tsize) {
-                pi[ii] = p_in[j];
-            }
-        }
-#pragma unroll
-        for(int ii = 0; ii < ILP; ii++) {
-            convert(pi[ii], po[ii]);
-            if (!isfinite(po[ii])) {
-                overflow = true;
-            }
-        }
-#pragma unroll
-        for(int ii = 0; ii < ILP; ii++) {
-            int j = j_start + i + totThreads*ii;
-            if (j < tsize) {
-                p_out[j] = po[ii];
-            }
-        }
-    }
-    if (overflow) {
-        p_out[0] = INFINITY;
-    }
-}
 template <typename T, typename GRAD_T>
-__global__ void adam_undo_cuda_kernel(
+__global__ void maybe_adam_undo_cuda_kernel(
+        volatile int* overflow_flag,
         T* __restrict__ p,
         T* __restrict__ m,
         T* __restrict__ v,
@@ -308,6 +304,9 @@ __global__ void adam_undo_cuda_kernel(
         adamMode_t mode,
         const float decay)
 {
+    // NB! Skip undo kernel when overflow flag is NOT set
+    if (overflow_flag && *overflow_flag == 0) return;
+
     //Assuming 2D grids and 2D blocks
     const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
     const int threadsPerBlock = blockDim.x * blockDim.y;
@@ -367,15 +366,46 @@ __global__ void adam_undo_cuda_kernel(
     }
 }
+__global__ void update_step_and_loss_scaler_kernel(
+        volatile int* overflow_flag,
+        double* __restrict__ step_and_loss_scaler_vec)
+{
+    // 0 : step
+    // 1 : iter
+    // 2 : loss_scale
+    // 3 : last_overflow_iter
+    // 4 : scale_factor
+    // 5 : scale_window
+    if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
+        double loss_scale = step_and_loss_scaler_vec[2];
+        double scale_factor = step_and_loss_scaler_vec[4];
+        int iter = static_cast<int>(step_and_loss_scaler_vec[1]);
+        int last_overflow_iter = static_cast<int>(step_and_loss_scaler_vec[3]);
+        if (*overflow_flag == 0) {
+            // increase step
+            step_and_loss_scaler_vec[0] += 1.0;
+            // maybe increase loss scaler
+            int scale_window = static_cast<int>(step_and_loss_scaler_vec[5]);
+            if (((iter - last_overflow_iter) % scale_window) == 0) {
+                step_and_loss_scaler_vec[2] = loss_scale * scale_factor;
+            }
+        } else {
+            step_and_loss_scaler_vec[2] = loss_scale / scale_factor;
+            step_and_loss_scaler_vec[3] = static_cast<double>(iter);
+        }
+        step_and_loss_scaler_vec[1] += 1.0;
+    }
+}
 template <int DEPTH, typename FROM_T, typename TO_T>
-struct UnpackE5M2Functor
+struct MaybeCastFunctor
 {
     __device__ __forceinline__ void operator()(
         int chunk_size,
-        volatile int* noop_gmem,
+        volatile int* overflow_flag,
         TensorListMetadata<DEPTH>& tl)
     {
-        if (*noop_gmem != 0) return;
+        if (overflow_flag && *overflow_flag != 0) return;
         int tensor_loc = tl.block_to_tensor[blockIdx.x];
         int chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -392,7 +422,6 @@ struct UnpackE5M2Functor
         FROM_T pi[ILP];
         TO_T po[ILP];

-        bool overflow = false;
         for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
 #pragma unroll
             for(int ii = 0; ii < ILP; ii++) {
@@ -406,9 +435,6 @@ struct UnpackE5M2Functor
 #pragma unroll
             for(int ii = 0; ii < ILP; ii++) {
                 convert(pi[ii], po[ii]);
-                if (!isfinite(po[ii])) {
-                    overflow = true;
-                }
             }
 #pragma unroll
@@ -419,10 +445,6 @@ struct UnpackE5M2Functor
                 }
             }
         }
-
-        if (overflow) {
-            *noop_gmem = 1;
-        }
     }
 };
@@ -431,7 +453,7 @@ struct AdamFunctor
 {
     __device__ __forceinline__ void operator()(
         int chunk_size,
-        volatile int* noop_gmem,
+        volatile int* overflow_flag,
         TensorListMetadata<DEPTH>& tl,
         const float b1,
         const float b2,
@@ -516,17 +538,17 @@ struct AdamFunctor
         }

         if (overflow) {
-            *noop_gmem = 1;
+            *overflow_flag = 1;
         }
     }
 };

 template <int DEPTH, typename T, typename GRAD_T>
-struct AdamUndoFunctor
+struct MaybeAdamUndoFunctor
 {
     __device__ __forceinline__ void operator()(
         int chunk_size,
-        volatile int* noop_gmem,
+        volatile int* overflow_flag,
         TensorListMetadata<DEPTH>& tl,
         const float b1,
         const float b2,
@@ -536,6 +558,9 @@ struct AdamUndoFunctor
         adamMode_t mode,
         const float decay)
     {
+        // Skip Adam undo when overflow flag is NOT set
+        if (overflow_flag && *overflow_flag == 0) return;
+
         int tensor_loc = tl.block_to_tensor[blockIdx.x];
         int chunk_idx = tl.block_to_chunk[blockIdx.x];
         int n = tl.sizes[tensor_loc];
@@ -606,7 +631,7 @@ struct AdamUndoFunctor
 };
 void fused_strided_check_finite(
-        at::Tensor & noop,
+        at::Tensor & overflow_flag,
         at::Tensor & p_copy,
         int stride,
         int clear_overflow_first)
@@ -624,7 +649,7 @@ void fused_strided_check_finite(
     using namespace at; // prevents "toString is undefined" errors
     DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
         strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-            noop.DATA_PTR<int>(),
+            overflow_flag.DATA_PTR<int>(),
             p_copy.DATA_PTR<scalar_t_0>(),
             tsize,
             stride,
@@ -734,7 +759,8 @@ void fused_adam_cuda(
     THCudaCheck(cudaGetLastError());
 }
-void unpack_e5m2_cuda(
+void maybe_cast_cuda(
+        at::Tensor & overflow_flag,
         at::Tensor & p_in,
         at::Tensor & p_out)
 {
@@ -747,20 +773,19 @@ void unpack_e5m2_cuda(
     AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
     //Constants
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    AT_ASSERTM(p_in.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
-    AT_ASSERTM(p_out.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
-    DISPATCH_FLOAT_AND_HALF(p_out.scalar_type(), 0, "unpack_e5m2",
-        unpack_e5m2_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-            p_in.DATA_PTR<uint8_t>(),
-            p_out.DATA_PTR<scalar_t_0>(),
-            tsize);
-    );
+    DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda"
+        DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, "maybe_cast_cuda",
+            maybe_cast_kernel<scalar_t_0,scalar_t_1><<<blocks,threadsPerBlock, 0, stream>>>(
+                overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
+                p_in.DATA_PTR<scalar_t_0>(),
+                p_out.DATA_PTR<scalar_t_1>(),
+                tsize); ))
     THCudaCheck(cudaGetLastError());
 }

-void unpack_e5m2_cuda_mt(
+void maybe_cast_cuda_mt(
     int chunk_size,
-    at::Tensor noop_flag,
+    at::Tensor overflow_flag,
     std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out
 {
     //Constants
@@ -769,18 +794,31 @@ void unpack_e5m2_cuda_mt(
     size_t tl_sz = tensor_lists.size();
     AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2");

-    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 0, "unpack_e5m2_cuda_mt_kernel",
-        multi_tensor_apply<2>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            UnpackE5M2Functor<2, uint8_t, scalar_t_0>());
-    );
+    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel",
+        DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel",
+            multi_tensor_apply<2>(
+                BLOCK_SIZE,
+                chunk_size,
+                overflow_flag,
+                tensor_lists,
+                MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); ))
     THCudaCheck(cudaGetLastError());
 }
-void fused_adam_undo_cuda(
+void update_step_and_loss_scaler_cuda(
+        at::Tensor & overflow_flag,
+        at::Tensor & step_and_loss_scaler)
+{
+    AT_ASSERTM(step_and_loss_scaler.numel() == 6, "step_and_loss_scaler must have 6 elements");
+    AT_ASSERTM(step_and_loss_scaler.scalar_type() == at::ScalarType::Double, "expected step_and_loss_scaler to be a double tensor");
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    update_step_and_loss_scaler_kernel<<<1,1,0,stream>>>(
+        overflow_flag.DATA_PTR<int>(),
+        step_and_loss_scaler.DATA_PTR<double>());
+}
+
+void fused_maybe_adam_undo_cuda(
+        at::Tensor & overflow_flag,
         at::Tensor & p,
         at::Tensor & m,
         at::Tensor & v,
@@ -795,72 +833,71 @@ void fused_adam_undo_cuda(
         int bias_correction,
         float decay)
 {
-    // using namespace at;
-
     //Get tensor size
     int tsize = p.numel();
     //Determine #threads and #blocks
     const int threadsPerBlock = 512;
     const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
     AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
     //Constants
     float step_size = 0;
     if (bias_correction == 1) {
         const float bias_correction1 = 1 - std::pow(beta1, step);
         const float bias_correction2 = 1 - std::pow(beta2, step);
         step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
     }
     else {
         step_size = lr;
     }
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();

     if (g.scalar_type() == at::ScalarType::Half) {
         //all other values should be fp32 for half gradients
         AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
         //dispatch is done on the gradient type
         using namespace at; // prevents "toString is undefined" errors
         DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
             using accscalar_t = at::acc_type<scalar_t_0, true>;
-            adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+            maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+                overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
                 p.DATA_PTR<accscalar_t>(),
                 m.DATA_PTR<accscalar_t>(),
                 v.DATA_PTR<accscalar_t>(),
                 g.DATA_PTR<scalar_t_0>(),
                 beta1,
                 beta2,
                 eps,
                 grad_scale,
                 step_size,
                 tsize,
                 (adamMode_t) mode,
                 decay);
         );
     } else {
         using namespace at;
         DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
-            adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+            maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+                overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
                 p.DATA_PTR<scalar_t_0>(),
                 m.DATA_PTR<scalar_t_0>(),
                 v.DATA_PTR<scalar_t_0>(),
                 g.DATA_PTR<scalar_t_0>(),
                 beta1,
                 beta2,
                 eps,
                 grad_scale,
                 step_size,
                 tsize,
                 (adamMode_t) mode,
                 decay);
         );
     }
     THCudaCheck(cudaGetLastError());
 }
 void fused_adam_cuda_mt(
     int chunk_size,
-    at::Tensor noop_flag,
+    at::Tensor overflow_flag,
     std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
     float lr,
     float beta1,
@@ -897,7 +934,7 @@ void fused_adam_cuda_mt(
             multi_tensor_apply<5>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
                 AdamFunctor<5, accscalar_t, scalar_t_0>(),
                 beta1,
@@ -914,7 +951,7 @@ void fused_adam_cuda_mt(
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
                 AdamFunctor<4, accscalar_t, scalar_t_0>(),
                 beta1,
@@ -932,7 +969,7 @@ void fused_adam_cuda_mt(
             multi_tensor_apply<5>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
                 AdamFunctor<5, scalar_t_0, scalar_t_0>(),
                 beta1,
@@ -948,7 +985,7 @@ void fused_adam_cuda_mt(
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
                 AdamFunctor<4, scalar_t_0, scalar_t_0>(),
                 beta1,
@@ -964,9 +1001,9 @@ void fused_adam_cuda_mt(
     THCudaCheck(cudaGetLastError());
 }

-void fused_adam_undo_cuda_mt(
+void fused_maybe_adam_undo_cuda_mt(
     int chunk_size,
-    at::Tensor noop_flag,
+    at::Tensor overflow_flag,
     std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
     float lr,
     float beta1,
@@ -997,14 +1034,14 @@ void fused_adam_undo_cuda_mt(
         //alher values should be fp32 for half gradients
         AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
         //dich is done on the gradient type
-        DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
+        DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
             using accscalar_t = at::acc_type<scalar_t_0, true>;
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
-                AdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
+                MaybeAdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
                 beta1,
                 beta2,
                 eps,
@@ -1014,13 +1051,13 @@ void fused_adam_undo_cuda_mt(
                 decay);
         );
     } else {
-        DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
+        DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
-                AdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
+                MaybeAdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
                 beta1,
                 beta2,
                 eps,
...
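For clarity, the single-thread update_step_and_loss_scaler_kernel above implements a standard dynamic loss-scaling policy. A host-side Python restatement of the same update (illustrative only, written against the 6-element layout documented in the kernel) looks like this:

def update_step_and_loss_scaler(overflow, state):
    # state layout: [step, iter, loss_scale, last_overflow_iter, scale_factor, scale_window]
    step, it, loss_scale, last_overflow_iter, scale_factor, scale_window = state
    if not overflow:
        state[0] = step + 1.0                        # advance the optimizer step
        if (int(it) - int(last_overflow_iter)) % int(scale_window) == 0:
            state[2] = loss_scale * scale_factor     # grow the loss scale after a clean window
    else:
        state[2] = loss_scale / scale_factor         # back off after an overflow
        state[3] = float(int(it))                    # remember the overflow iteration
    state[1] = it + 1.0                              # the iteration counter always advances
    return state

# Example: state = [0., 0., 2.**15, -1., 2., 1000.]; update_step_and_loss_scaler(False, state)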
@@ -154,6 +154,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 if torch.distributed.get_rank() in ranks:
                     self._ar_pg.append(grp)
         self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
+        for ar_pg in self._ar_pg:
+            torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
         rs_ranks = []
         for group_i in range(self._num_groups):
             rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
@@ -166,6 +168,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     self._rs_pg.append(grp)
             if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
                 self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+                torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
         self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg
@@ -180,6 +183,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     if torch.distributed.get_rank() in ranks:
                         self._ag_pg.append(grp)
         self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
+        for ag_pg in self._ag_pg:
+            torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
         self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
         self._completion_st = torch.cuda.Stream()
@@ -452,7 +457,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         beta1, beta2 = group['betas']
         if undo:
             if self._revert_method == 1:
-                fused_adam_cuda.adam_undo(
+                fused_adam_cuda.maybe_adam_undo(
+                                         torch.empty([0]),
                                          self._fp32_p[group_buffer_start:group_buffer_end],
                                          self._fp32_m[group_buffer_start:group_buffer_end],
                                          self._fp32_v[group_buffer_start:group_buffer_end],
@@ -576,7 +582,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 param_i += 1
         if self._e5m2_allgather:
             multi_tensor_applier(
                     fused_adam_cuda.maybe_cast_mt,
                     self._overflow_buf,
                     [p_in, p_out]);
        elif self._do_not_flatten_model:
...
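One detail worth noting in the revert path above: it passes torch.empty([0]) as the overflow flag. The host wrappers forward a NULL pointer when the flag tensor is empty (the overflow_flag.numel() ? ... : NULL pattern), and the kernels only consult the flag when the pointer is non-NULL, so an empty tensor turns maybe_adam_undo into an unconditional undo. A small illustrative sketch, with placeholder tensor sizes and hyperparameters:

import torch
import fused_adam_cuda

# Placeholder fp32 buffers, as in the sketch after the C++ interface above.
p = torch.zeros(1024, device='cuda')
m, v, g = torch.zeros_like(p), torch.zeros_like(p), torch.zeros_like(p)

# Empty flag tensor -> NULL pointer in the wrapper -> the undo always runs,
# which is what the _revert_method == 1 branch above relies on.
fused_adam_cuda.maybe_adam_undo(torch.empty([0]), p, m, v, g,
                                1e-3, 0.9, 0.999, 1e-8,  # lr, beta1, beta2, eps
                                65536.0, 1, 0, 1, 0.0)   # grad_scale, step, mode, bias_correction, decay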