Commit 25c80afe authored by Thor Johnsen

Re-introduce original non-reversible fused contrib adam cuda kernel

parent 9bb71066
@@ -4,6 +4,7 @@
void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
@@ -42,6 +43,20 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
    fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void reversible_adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
    CHECK_INPUT(p);
    CHECK_INPUT(m);
@@ -66,6 +81,7 @@ void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_ou
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
    m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation."); m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation."); m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation."); m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
...
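For orientation, here is a minimal sketch of driving the bindings above from Python. It assumes the extension was built under the name fused_adam_cuda and uses made-up hyperparameter values; p, m, v, g are flat fp32 CUDA tensors, and an empty p_copy means no reduced-precision parameter copy is written:

import torch
import fused_adam_cuda  # module name from TORCH_EXTENSION_NAME above

# Hypothetical flat optimizer state on the GPU.
n = 1 << 20
p = torch.zeros(n, device='cuda')
m = torch.zeros_like(p)
v = torch.zeros_like(p)
g = torch.randn_like(p)
p_copy = torch.empty(0, device='cuda')  # empty tensor => kernel receives NULL

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
grad_scale, step, mode, bias_correction, decay = 1.0, 1, 0, 1, 0.0

# Non-reversible fused kernel re-introduced by this commit.
fused_adam_cuda.adam(p, p_copy, m, v, g, lr, beta1, beta2, eps,
                     grad_scale, step, mode, bias_correction, decay)

# Reversible variant, now used by the DistributedFusedAdam optimizers below.
fused_adam_cuda.reversible_adam(p, p_copy, m, v, g, lr, beta1, beta2, eps,
                                grad_scale, step, mode, bias_correction, decay)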
@@ -202,8 +202,46 @@ __global__ void maybe_cast_kernel(
    }
}
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T * __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay)
{
    //Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
    // Grid-stride loop: each thread processes every totThreads-th element.
    for (int j = i; j < tsize; j+=totThreads) {
        // Unscale the gradient, then update the biased first and second moments.
        T scaled_grad = g[j]/grad_scale;
        m[j] = b1*m[j] + (1-b1)*scaled_grad;
        v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
        float denom;
        if (mode == ADAM_MODE_0) // eps inside the sqrt
            denom = sqrtf(v[j] + eps);
        else // Mode 1: eps outside the sqrt
            denom = sqrtf(v[j]) + eps;
        float update = (m[j]/denom) + (decay*p[j]);
        p[j] = p[j] - (step_size*update);
        // Optionally emit a reduced-precision copy of the updated parameter.
        if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
    }
}
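In equations, each element j handled by the kernel above receives the standard Adam update; grad_scale un-scales the gradient and step_size is the bias-corrected learning rate precomputed on the host (see fused_adam_cuda below):

\hat{g}_j = g_j / \mathrm{grad\_scale}, \qquad
m_j \leftarrow \beta_1 m_j + (1-\beta_1)\,\hat{g}_j, \qquad
v_j \leftarrow \beta_2 v_j + (1-\beta_2)\,\hat{g}_j^{\,2}

d_j = \begin{cases} \sqrt{v_j + \epsilon} & \text{mode 0} \\ \sqrt{v_j} + \epsilon & \text{mode 1} \end{cases}, \qquad
p_j \leftarrow p_j - \mathrm{step\_size}\,\bigl(m_j / d_j + \mathrm{decay} \cdot p_j\bigr)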
template <typename T, typename GRAD_T, typename REDU_T>
__global__ void reversible_adam_cuda_kernel(
    T* __restrict__ p,
    REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
@@ -674,6 +712,86 @@ void fused_adam_cuda(
    int bias_correction,
    float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
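Note that the host code above folds Adam's bias correction into the step size instead of correcting m and v inside the kernel; with t = step:

\mathrm{step\_size} = \begin{cases} \mathrm{lr} \cdot \sqrt{1-\beta_2^{\,t}} \,/\, (1-\beta_1^{\,t}) & \text{if bias\_correction} = 1 \\ \mathrm{lr} & \text{otherwise} \end{cases}

The kernel is launched with ceil(tsize/512) blocks of 512 threads each; because of the grid-stride loop in the kernel, any grid shape still covers all tsize elements.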
void fused_reversible_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
    // using namespace at;
    //Get tensor size
@@ -702,7 +820,7 @@ void fused_adam_cuda(
        if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                    p.DATA_PTR<accscalar_t>(),
                    p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
                    m.DATA_PTR<accscalar_t>(),
@@ -721,7 +839,7 @@ void fused_adam_cuda(
            AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(
                    p.DATA_PTR<accscalar_t>(),
                    p_copy.DATA_PTR<uint8_t>(),
                    m.DATA_PTR<accscalar_t>(),
@@ -740,7 +858,7 @@ void fused_adam_cuda(
    } else {
        using namespace at;
        DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
            reversible_adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                p.DATA_PTR<scalar_t_0>(),
                NULL, //don't output p_copy for fp32, it's wasted write
                m.DATA_PTR<scalar_t_0>(),
...
@@ -360,7 +360,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.reversible_adam(
            p, p_copy, m, v, g,
            self._param_group['lr'],
            beta1,
...
@@ -413,7 +413,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
        combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.reversible_adam(
            p, p_copy, m, v, g,
            self._param_group['lr'],
            beta1,
...
@@ -227,7 +227,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
        combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.reversible_adam(
            p, p_copy, m, v, g,
            self._param_group['lr'],
            beta1,
@@ -325,6 +325,8 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
            for p in self._model_params: self.state[p]['step'] += 1
            torch.cuda.current_stream().wait_stream(self._dwu_st)
        else:
            print("Overflow detected, skipping step")
        return loss
...
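The else branch added above makes skipped steps visible. Schematically, the surrounding step logic follows this pattern (a sketch with hypothetical names; has_overflow stands in for whatever finite-check result the optimizer reduces):

# Sketch only: apply the fused update and advance optimizer state when the
# gradients are finite, otherwise leave everything untouched and report it.
if not has_overflow:
    fused_adam_cuda.reversible_adam(p, p_copy, m, v, g, lr, beta1, beta2,
                                    eps, combined_scale, step, mode,
                                    bias_correction, decay)
    for param in self._model_params:
        self.state[param]['step'] += 1
    torch.cuda.current_stream().wait_stream(self._dwu_st)
else:
    print("Overflow detected, skipping step")
return loss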