Commit 4a01ff26 authored by Thor Johnsen

Partial move towards syncfree optimizer

parent 2622d7f1
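The diff renames the noop/noop_flag overflow indicators to overflow_flag, adds maybe_* kernel variants that consult the flag on the GPU instead of the host, and introduces an update_step_and_loss_scaler kernel, so dynamic loss scaling can run without a device-to-host synchronization. A minimal sketch of the intended driver flow (tensor names and initial values are illustrative, not part of this commit):

import torch
import fused_adam_cuda  # the extension built from the sources below

overflow_flag = torch.zeros(1, dtype=torch.int32, device='cuda')
# layout: [step, iter, loss_scale, last_overflow_iter, scale_factor, scale_window]
state = torch.tensor([1., 0., 2.**15, -1., 2., 1000.],
                     dtype=torch.float64, device='cuda')

# ... fused Adam kernels run and set overflow_flag if they see non-finite values ...
# The undo and the loss-scale update both read the flag on the GPU, so the host
# never blocks on a flag readback:
# fused_adam_cuda.maybe_adam_undo(overflow_flag, p, m, v, g, ...)
fused_adam_cuda.update_step_and_loss_scaler(overflow_flag, state)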
#include <torch/extension.h>
// CUDA forward declaration
void fused_strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first);
void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_maybe_adam_undo_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);
void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
void update_step_and_loss_scaler_cuda(at::Tensor & overflow_flag, at::Tensor & step_and_loss_scaler);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -18,13 +20,13 @@ void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::
// C++ interface
void strided_check_finite(
at::Tensor& noop,
at::Tensor& overflow_flag,
at::Tensor& p_copy,
int stride,
int clear_overflow_first
) {
CHECK_INPUT(p_copy);
fused_strided_check_finite(noop, p_copy, stride, clear_overflow_first);
fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);
}
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
@@ -40,7 +42,7 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
CHECK_INPUT(m);
CHECK_INPUT(v);
@@ -50,23 +52,24 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void unpack_e5m2(at::Tensor & p_in, at::Tensor & p_out) {
void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {
CHECK_INPUT(p_in);
CHECK_INPUT(p_out);
int64_t num_elem = p_in.numel();
AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
unpack_e5m2_cuda(p_in, p_out);
maybe_cast_cuda(overflow_flag, p_in, p_out);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
m.def("unpack_e5m2_mt", &unpack_e5m2_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
m.def("update_step_and_loss_scaler", &update_step_and_loss_scaler_cuda, "Update step and loss scaler.");
}
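A hedged usage sketch of the maybe_cast binding defined above (tensor names are illustrative; per the NULL fallback in maybe_cast_cuda further down, an empty flag tensor disables the overflow check):

import torch
import fused_adam_cuda

overflow_flag = torch.zeros(1, dtype=torch.int32, device='cuda')
p_in = torch.randn(1024, device='cuda', dtype=torch.float16)
p_out = torch.empty(1024, device='cuda', dtype=torch.float32)

# casts p_in into p_out unless *overflow_flag is already non-zero
fused_adam_cuda.maybe_cast(overflow_flag, p_in, p_out)
# an empty flag skips the check and always casts
fused_adam_cuda.maybe_cast(torch.empty([0], dtype=torch.int32, device='cuda'), p_in, p_out)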
@@ -157,6 +157,51 @@ __global__ void strided_check_finite_cuda_kernel(
}
}
template <typename FROM_T, typename TO_T>
__global__ void maybe_cast_kernel(
volatile int* overflow_flag,
const FROM_T* p_in,
TO_T* p_out,
const size_t tsize)
{
if (overflow_flag && *overflow_flag != 0) return;
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
FROM_T pi[ILP];
TO_T po[ILP];
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = 0;
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
p_out[j] = po[ii];
}
}
}
}
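For intuition, a small Python mirror of the access pattern maybe_cast_kernel shares with the other kernels in this file: each thread handles ILP elements per grid-stride pass, and consecutive threads touch consecutive addresses so loads coalesce. The constants here are illustrative:

ILP, tot_threads, tsize = 4, 8, 30
touched = []
for i in range(tot_threads):                     # i plays the role of the global thread id
    for j_start in range(0, tsize, tot_threads * ILP):
        for ii in range(ILP):
            j = j_start + i + tot_threads * ii   # adjacent threads hit adjacent j
            if j < tsize:
                touched.append(j)
assert sorted(touched) == list(range(tsize))     # every element visited exactly once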
template <typename T, typename GRAD_T, typename REDU_T>
__global__ void adam_cuda_kernel(
T* __restrict__ p,
@@ -243,58 +288,9 @@ __global__ void adam_cuda_kernel(
}
}
template <typename GRAD_T>
__global__ void unpack_e5m2_kernel(
const uint8_t* p_in,
GRAD_T* p_out,
const size_t tsize)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
uint8_t pi[ILP];
GRAD_T po[ILP];
bool overflow = false;
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = 0;
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
if (!isfinite(po[ii])) {
overflow = true;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
p_out[j] = po[ii];
}
}
}
if (overflow) {
p_out[0] = INFINITY;
}
}
template <typename T, typename GRAD_T>
__global__ void adam_undo_cuda_kernel(
__global__ void maybe_adam_undo_cuda_kernel(
volatile int* overflow_flag,
T* __restrict__ p,
T* __restrict__ m,
T* __restrict__ v,
@@ -308,6 +304,9 @@ __global__ void adam_undo_cuda_kernel(
adamMode_t mode,
const float decay)
{
// NB! Skip undo kernel when overflow flag is NOT set
if (overflow_flag && *overflow_flag == 0) return;
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
@@ -367,15 +366,46 @@ __global__ void adam_undo_cuda_kernel(
}
}
__global__ void update_step_and_loss_scaler_kernel(
volatile int* overflow_flag,
double* __restrict__ step_and_loss_scaler_vec)
{
// 0 : step
// 1 : iter
// 2 : loss_scale
// 3 : last_overflow_iter
// 4 : scale_factor
// 5 : scale_window
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
double loss_scale = step_and_loss_scaler_vec[2];
double scale_factor = step_and_loss_scaler_vec[4];
int iter = static_cast<int>(step_and_loss_scaler_vec[1]);
int last_overflow_iter = static_cast<int>(step_and_loss_scaler_vec[3]);
if (*overflow_flag == 0) {
// increase step
step_and_loss_scaler_vec[0] += 1.0;
// maybe increase loss scaler
int scale_window = static_cast<int>(step_and_loss_scaler_vec[5]);
if (((iter - last_overflow_iter) % scale_window) == 0) {
step_and_loss_scaler_vec[2] = loss_scale * scale_factor;
}
} else {
step_and_loss_scaler_vec[2] = loss_scale / scale_factor;
step_and_loss_scaler_vec[3] = static_cast<double>(iter);
}
step_and_loss_scaler_vec[1] += 1.0;
}
}
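The kernel above runs on a single thread and keeps the whole dynamic loss-scaling state in one six-element double tensor, so the scale adapts without copying the overflow flag back to the host. The same decision logic rendered in Python for reference (a sketch, not part of the commit):

def update_step_and_loss_scaler(overflow, state):
    step, it, loss_scale, last_overflow_iter, scale_factor, scale_window = state
    if not overflow:
        step += 1.0                                    # the optimizer step stands
        if (int(it) - int(last_overflow_iter)) % int(scale_window) == 0:
            loss_scale *= scale_factor                 # grow the scale after a quiet window
    else:
        loss_scale /= scale_factor                     # back off on overflow
        last_overflow_iter = float(int(it))
    it += 1.0
    return [step, it, loss_scale, last_overflow_iter, scale_factor, scale_window]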
template <int DEPTH, typename FROM_T, typename TO_T>
struct UnpackE5M2Functor
struct MaybeCastFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
volatile int* overflow_flag,
TensorListMetadata<DEPTH>& tl)
{
if (*noop_gmem != 0) return;
if (overflow_flag && *overflow_flag != 0) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -392,7 +422,6 @@ struct UnpackE5M2Functor
FROM_T pi[ILP];
TO_T po[ILP];
bool overflow = false;
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
@@ -406,9 +435,6 @@
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
if (!isfinite(po[ii])) {
overflow = true;
}
}
#pragma unroll
@@ -419,10 +445,6 @@
}
}
}
if (overflow) {
*noop_gmem = 1;
}
}
};
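MaybeCastFunctor is driven by multi_tensor_apply, which launches one grid over chunks of many tensors; each block finds its tensor and chunk through TensorListMetadata, so the overflow_flag early-out above retires whole chunks cheaply. Illustrative chunk bookkeeping in Python (names are ours, not the extension's):

chunk_size = 4
sizes = [10, 3]                          # flattened tensor lengths
block_to_tensor, block_to_chunk = [], []
for t, n in enumerate(sizes):
    for c in range((n + chunk_size - 1) // chunk_size):
        block_to_tensor.append(t)        # which tensor this block works on
        block_to_chunk.append(c)         # which chunk within that tensor
# one CUDA block per entry: blocks 0-2 cover tensor 0, block 3 covers tensor 1
assert len(block_to_tensor) == 4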
@@ -431,7 +453,7 @@ struct AdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
volatile int* overflow_flag,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
@@ -516,17 +538,17 @@
}
if (overflow) {
*noop_gmem = 1;
*overflow_flag = 1;
}
}
};
template <int DEPTH, typename T, typename GRAD_T>
struct AdamUndoFunctor
struct MaybeAdamUndoFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
volatile int* overflow_flag,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
@@ -536,6 +558,9 @@ struct AdamUndoFunctor
adamMode_t mode,
const float decay)
{
// Skip Adam undo when overflow flag is NOT set
if (overflow_flag && *overflow_flag == 0) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
@@ -606,7 +631,7 @@ struct AdamUndoFunctor
};
void fused_strided_check_finite(
at::Tensor & noop,
at::Tensor & overflow_flag,
at::Tensor & p_copy,
int stride,
int clear_overflow_first)
@@ -624,7 +649,7 @@ void fused_strided_check_finite(
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
noop.DATA_PTR<int>(),
overflow_flag.DATA_PTR<int>(),
p_copy.DATA_PTR<scalar_t_0>(),
tsize,
stride,
@@ -734,7 +759,8 @@ void fused_adam_cuda(
THCudaCheck(cudaGetLastError());
}
void unpack_e5m2_cuda(
void maybe_cast_cuda(
at::Tensor & overflow_flag,
at::Tensor & p_in,
at::Tensor & p_out)
{
@@ -747,20 +773,19 @@ void unpack_e5m2_cuda(
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
//Constants
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_ASSERTM(p_in.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
AT_ASSERTM(p_out.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
DISPATCH_FLOAT_AND_HALF(p_out.scalar_type(), 0, "unpack_e5m2",
unpack_e5m2_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p_in.DATA_PTR<uint8_t>(),
p_out.DATA_PTR<scalar_t_0>(),
tsize);
);
DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda"
DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, "maybe_cast_cuda",
maybe_cast_kernel<scalar_t_0,scalar_t_1><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p_in.DATA_PTR<scalar_t_0>(),
p_out.DATA_PTR<scalar_t_1>(),
tsize); ))
THCudaCheck(cudaGetLastError());
}
void unpack_e5m2_cuda_mt(
void maybe_cast_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
at::Tensor overflow_flag,
std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out
{
//Constants
@@ -769,18 +794,31 @@ void unpack_e5m2_cuda_mt(
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2");
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 0, "unpack_e5m2_cuda_mt_kernel",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
UnpackE5M2Functor<2, uint8_t, scalar_t_0>());
);
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel",
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); ))
THCudaCheck(cudaGetLastError());
}
void fused_adam_undo_cuda(
void update_step_and_loss_scaler_cuda(
at::Tensor & overflow_flag,
at::Tensor & step_and_loss_scaler)
{
AT_ASSERTM(step_and_loss_scaler.numel() == 6, "step_and_loss_scaler must have 6 elements");
AT_ASSERTM(step_and_loss_scaler.scalar_type() == at::ScalarType::Double, "expected step_and_loss_scaler to be a double tensor");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
update_step_and_loss_scaler_kernel<<<1,1,0,stream>>>(
overflow_flag.DATA_PTR<int>(),
step_and_loss_scaler.DATA_PTR<double>());
}
void fused_maybe_adam_undo_cuda(
at::Tensor & overflow_flag,
at::Tensor & p,
at::Tensor & m,
at::Tensor & v,
@@ -795,72 +833,71 @@ void fused_adam_undo_cuda(
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p.DATA_PTR<accscalar_t>(),
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p.DATA_PTR<scalar_t_0>(),
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
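Both the removed and the new host wrappers fold Adam's bias correction into a single step_size before launch, so the kernels only ever see one scalar. The arithmetic, checked in Python with illustrative hyperparameters:

import math

lr, beta1, beta2, step = 1e-3, 0.9, 0.999, 10
bias_correction1 = 1.0 - beta1 ** step        # approx 0.6513
bias_correction2 = 1.0 - beta2 ** step        # approx 0.00996
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
print(step_size)                              # approx 1.53e-04; with bias_correction=0 it is just lr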
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
at::Tensor overflow_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
@@ -897,7 +934,7 @@ void fused_adam_cuda_mt(
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
overflow_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
@@ -914,7 +951,7 @@ void fused_adam_cuda_mt(
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
overflow_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
@@ -932,7 +969,7 @@ void fused_adam_cuda_mt(
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
overflow_flag,
tensor_lists,
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
@@ -948,7 +985,7 @@ void fused_adam_cuda_mt(
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
overflow_flag,
tensor_lists,
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
@@ -964,9 +1001,9 @@ void fused_adam_cuda_mt(
THCudaCheck(cudaGetLastError());
}
void fused_adam_undo_cuda_mt(
void fused_maybe_adam_undo_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
at::Tensor overflow_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
@@ -997,14 +1034,14 @@ void fused_adam_undo_cuda_mt(
//all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
overflow_flag,
tensor_lists,
AdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
MaybeAdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
@@ -1014,13 +1051,13 @@ void fused_adam_undo_cuda_mt(
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
overflow_flag,
tensor_lists,
AdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
MaybeAdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
@@ -154,6 +154,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
@@ -166,6 +168,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
@@ -180,6 +183,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
self._completion_st = torch.cuda.Stream()
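The all_reduce calls added after each new_group above read like communicator warm-up: process groups are set up lazily on their first collective, so reducing the small _overflow_buf once per group moves that cost out of the first training step. The pattern in isolation, assuming an initialized default process group:

import torch
import torch.distributed as dist

# assumes dist.init_process_group(...) has already run on every rank
buf = torch.zeros(1, dtype=torch.int32, device='cuda')
grp = dist.new_group(ranks=list(range(dist.get_world_size())))
dist.all_reduce(buf, group=grp)  # first collective on grp creates its communicator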
@@ -452,7 +457,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
beta1, beta2 = group['betas']
if undo:
if self._revert_method == 1:
fused_adam_cuda.adam_undo(
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p[group_buffer_start:group_buffer_end],
self._fp32_m[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end],
@@ -576,7 +582,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
param_i += 1
if self._e5m2_allgather:
multi_tensor_applier(
fused_adam_cuda.unpack_e5m2_mt,
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
[p_in, p_out]);
elif self._do_not_flatten_model: