Commit 029cd5e1 authored by Thor Johnsen's avatar Thor Johnsen

Avoid 1 copy for double buffering scheme

parent b85ff391
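The double buffering scheme referenced in the commit message, as a minimal PyTorch sketch (illustrative only: adam_step_out_of_place and the buffer names below are hypothetical, and bias correction, weight decay and the mixed-precision p_copy output are omitted). Instead of copy_()-ing p/m/v into backup buffers before each step, the fused kernel reads optimizer state from one buffer set and writes the updated state into a second set, so keeping a revertible copy of the state costs a pointer swap rather than a full copy.

import torch

def adam_step_out_of_place(p_in, m_in, v_in, g, p_out, m_out, v_out,
                           lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Out-of-place Adam: read the current state, write the updated state into
    # the second buffer set instead of overwriting the inputs.
    m_out.copy_(beta1 * m_in + (1 - beta1) * g)
    v_out.copy_(beta2 * v_in + (1 - beta2) * g * g)
    p_out.copy_(p_in - lr * m_out / (v_out.sqrt() + eps))

# Two buffer sets whose roles swap every step.
p, m, v = torch.randn(4), torch.zeros(4), torch.zeros(4)
cur = [p, m, v]
bak = [torch.empty_like(t) for t in cur]
g = torch.randn(4)

adam_step_out_of_place(*cur, g, *bak)   # updated state lands in 'bak'
cur, bak = bak, cur                     # swap: 'bak' now holds the pre-step state
# Reverting the step (e.g. after a gradient overflow) is just another swap:
cur, bak = bak, cur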
@@ -3,6 +3,7 @@
// CUDA forward declaration
void fused_strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first);
void fused_adam_cuda_no_overflow_check(at::Tensor & p_in, at::Tensor & p_out, at::Tensor & p_copy, at::Tensor & m_in, at::Tensor & m_out, at::Tensor & v_in, at::Tensor & v_out, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
@@ -24,6 +25,26 @@ void strided_check_finite(
        CHECK_INPUT(p_copy);
        fused_strided_check_finite(noop, p_copy, stride, clear_overflow_first);
}
void adam_no_overflow_check(at::Tensor & p_in, at::Tensor & p_out, at::Tensor & p_copy, at::Tensor & m_in, at::Tensor & m_out, at::Tensor & v_in, at::Tensor & v_out, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p_in);
CHECK_INPUT(p_out);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m_in);
CHECK_INPUT(m_out);
CHECK_INPUT(v_in);
CHECK_INPUT(v_out);
CHECK_INPUT(g);
int64_t num_elem = p_in.numel();
AT_ASSERTM(m_in.numel() == num_elem, "number of elements in m_in and p_in tensors should be equal");
AT_ASSERTM(m_out.numel() == num_elem, "number of elements in m_out and p_in tensors should be equal");
AT_ASSERTM(v_in.numel() == num_elem, "number of elements in v_in and p_in tensors should be equal");
AT_ASSERTM(v_out.numel() == num_elem, "number of elements in v_out and p_in tensors should be equal");
AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_out and p_in tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
fused_adam_cuda_no_overflow_check(p_in, p_out, p_copy, m_in, m_out, v_in, v_out, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
        CHECK_INPUT(p);
        if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
@@ -53,6 +74,7 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
m.def("adam_no_overflow_check", &adam, "Adam optimized CUDA implementation.");
m.def("adam", &adam, "Adam optimized CUDA implementation."); m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation."); m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation."); m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
......
@@ -51,6 +51,81 @@ __global__ void strided_check_finite_cuda_kernel(
        }
}
template <typename T, typename GRAD_T>
__global__ void adam_cuda_no_overflow_check_kernel(
T* __restrict__ p_in,
T* __restrict__ p_out,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m_in,
T* __restrict__ m_out,
T* __restrict__ v_in,
T* __restrict__ v_out,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T mi[ILP];
T vi[ILP];
T pi[ILP];
T gi[ILP];
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
                gi[ii] = T(0);
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p_in[j];
mi[ii] = m_in[j];
vi[ii] = v_in[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
T scaled_grad = gi[ii]/grad_scale;
mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vi[ii] + eps);
else // Mode 1
denom = sqrtf(vi[ii]) + eps;
float update = (mi[ii]/denom) + (decay*pi[ii]);
pi[ii] = pi[ii] - (step_size*update);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
m_out[j] = mi[ii];
v_out[j] = vi[ii];
p_out[j] = pi[ii];
if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
}
}
}
}
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
        T* __restrict__ p,
@@ -416,6 +491,95 @@ void fused_strided_check_finite(
        THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_no_overflow_check(
at::Tensor & p_in,
at::Tensor & p_out,
at::Tensor & p_copy,
at::Tensor & m_in,
at::Tensor & m_out,
at::Tensor & v_in,
at::Tensor & v_out,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
        int tsize = p_in.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
                AT_ASSERTM(p_in.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel_no_overflow_check<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p_in.DATA_PTR<accscalar_t>(),
p_out.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m_in.DATA_PTR<accscalar_t>(),
m_out.DATA_PTR<accscalar_t>(),
v_in.DATA_PTR<accscalar_t>(),
v_out.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel_no_overflow_check<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p_in.DATA_PTR<accscalar_t>(),
p_out.DATA_PTR<accscalar_t>(),
NULL, //don't output p_copy for fp32, it's wasted write
m_in.DATA_PTR<accscalar_t>(),
m_out.DATA_PTR<accscalar_t>(),
v_in.DATA_PTR<accscalar_t>(),
v_out.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda(
        at::Tensor & p,
        at::Tensor & p_copy,
...
@@ -292,6 +292,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    def L2_grad_norm(self):
        return self._L2_grad_norm
    def _swap_optimizer_state_buffers(self):
p,m,v = self._fp32_p,self._fp32_m,self._fp32_v
self._fp32_p,self._fp32_m,self._fp32_v = self._fp32_backup_p,self._fp32_backup_m,self._fp32_backup_v
self._fp32_backup_p,self._fp32_backup_m,self._fp32_backup_v = p,m,v
    # Distributed weight update algorithm:
    # Model parameters are kept as-is.
    # Gradients are flattened during backprop.
@@ -404,16 +409,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                        bias_correction,
                        group['weight_decay'])
                elif self._revert_method == 2:
-                   self._fp32_p[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_p[group_buffer_start:group_buffer_end])
-                   self._fp32_m[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_m[group_buffer_start:group_buffer_end])
-                   self._fp32_v[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_v[group_buffer_start:group_buffer_end])
+                   self._swap_optimizer_state_buffers()
                elif self._revert_method == 3:
                    raise RuntimeError('revert_step debug option not implemented yet')
            else:
-               if self._revert_method > 1:
-                   self._fp32_backup_p[group_buffer_start:group_buffer_end].copy_(self._fp32_p[group_buffer_start:group_buffer_end])
-                   self._fp32_backup_m[group_buffer_start:group_buffer_end].copy_(self._fp32_m[group_buffer_start:group_buffer_end])
-                   self._fp32_backup_v[group_buffer_start:group_buffer_end].copy_(self._fp32_v[group_buffer_start:group_buffer_end])
+               if self._revert_method == 1:
                    fused_adam_cuda.adam(
                        self._fp32_p[group_buffer_start:group_buffer_end],
                        self._new_params[group_shard_start:group_shard_end],
@@ -429,6 +429,28 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                        self.eps_mode,
                        bias_correction,
                        group['weight_decay'])
elif self._revert_method == 2:
self._swap_optimizer_state_buffers()
fused_adam_cuda.adam_no_overflow_check(
self._fp32_backup_p[group_buffer_start:group_buffer_end],
self._fp32_p[group_buffer_start:group_buffer_end],
self._new_params[group_shard_start:group_shard_end],
self._fp32_backup_m[group_buffer_start:group_buffer_end],
self._fp32_m[group_buffer_start:group_buffer_end],
self._fp32_backup_v[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end],
self._flat_grads[group_shard_start:group_shard_end],
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
step+1,
self.eps_mode,
bias_correction,
group['weight_decay'])
elif self._revert_method == 3:
raise RuntimeError('revert_step debug option not implemented yet')
    def _do_compute_L2_grad_norm(self):
        partial_sum = torch.zeros([]).cuda()
...