Commit cd206434 authored by Thor Johnsen

Add e5m2 allgather option

parent aa90d31f
@@ -9,6 +9,7 @@ void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Te
 void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
 void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -50,11 +51,20 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f
     fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

+void unpack_e5m2(at::Tensor & p_in, at::Tensor & p_out) {
+    CHECK_INPUT(p_in);
+    CHECK_INPUT(p_out);
+    int64_t num_elem = p_in.numel();
+    AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
+    unpack_e5m2_cuda(p_in, p_out);
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
     m.def("adam", &adam, "Adam optimized CUDA implementation.");
     m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
+    m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
     m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
     m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
 }
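The frontend above wires the new unpack_e5m2 entry point into the existing fused_adam_cuda extension: p_in must be a contiguous CUDA byte tensor of packed e5m2 values and p_out a contiguous CUDA tensor with the same number of elements (the CUDA side additionally requires it to be fp16). A minimal usage sketch, not part of the commit, assuming the extension is built and importable as fused_adam_cuda (the name the optimizer below imports):

import importlib
import torch

fused_adam_cuda = importlib.import_module("fused_adam_cuda")

packed = torch.randint(0, 256, (1 << 20,), dtype=torch.uint8, device="cuda")  # raw e5m2 bytes
unpacked = torch.empty_like(packed, dtype=torch.float16)                      # fp16 destination, same numel
fused_adam_cuda.unpack_e5m2(packed, unpacked)  # decodes each byte into the corresponding fp16 value
# Note: if any decoded value is non-finite, the kernel also flags p_out[0] with inf.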
@@ -21,6 +21,82 @@ typedef enum{
     ADAM_MODE_1 =1 // eps outside square root
 } adamMode_t;

+template <typename FROM_T, typename TO_T>
+__device__ void convert(const FROM_T vi, TO_T& vo)
+{
+    vo = static_cast<TO_T>(vi);
+}
+
+template <>
+__device__ void convert(const float vi, uint8_t& vo)
+{
+    union S
+    {
+        float as_float;
+        int as_int;
+    };
+    S s;
+    s.as_float = vi;
+    s.as_int = s.as_int & 0xFF800000;
+    union T
+    {
+        at::Half as_half;
+        uint8_t as_byte[2];
+    };
+    T t;
+    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
+    vo = t.as_byte[1];
+}
+
+template <>
+__device__ void convert(const uint8_t vi, float& vo)
+{
+    union T
+    {
+        at::Half as_half;
+        uint8_t as_byte[2];
+    };
+    T t;
+    t.as_byte[0] = 0;
+    t.as_byte[1] = vi;
+    vo = static_cast<float>(t.as_half);
+}
+
+template <>
+__device__ void convert(const at::Half vi, uint8_t& vo)
+{
+    union S
+    {
+        float as_float;
+        int as_int;
+    };
+    S s;
+    s.as_float = static_cast<float>(vi);
+    s.as_int = s.as_int & 0xFF800000;
+    union T
+    {
+        at::Half as_half;
+        uint8_t as_byte[2];
+    };
+    T t;
+    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
+    vo = t.as_byte[1];
+}
+
+template <>
+__device__ void convert(const uint8_t vi, at::Half& vo)
+{
+    union T
+    {
+        at::Half as_half;
+        uint8_t as_byte[2];
+    };
+    T t;
+    t.as_byte[0] = 0;
+    t.as_byte[1] = vi;
+    vo = t.as_half;
+}
+
 template <typename GRAD_T>
 __global__ void strided_check_finite_cuda_kernel(
         volatile int* noop_gmem,
@@ -44,17 +120,17 @@ __global__ void strided_check_finite_cuda_kernel(
     }

     for (int j = i; j < tsize; j+=totThreads) {
         GRAD_T pi = p_copy[j];
         if (!isfinite(pi)) {
             *noop_gmem = 1;
         }
     }
 }

-template <typename T, typename GRAD_T>
+template <typename T, typename GRAD_T, typename REDU_T>
 __global__ void adam_cuda_kernel(
         T* __restrict__ p,
-        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
+        REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
         T* __restrict__ m,
         T* __restrict__ v,
         const GRAD_T * __restrict__ g,
@@ -122,7 +198,9 @@ __global__ void adam_cuda_kernel(
                 m[j] = mi[ii];
                 v[j] = vi[ii];
                 p[j] = pi[ii];
-                if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
+                if (p_copy != NULL) {
+                    convert(pi[ii], p_copy[j]);
+                }
             }
         }
     }
@@ -130,9 +208,59 @@ __global__ void adam_cuda_kernel(
     if (p_copy != NULL) {
         __syncthreads();
         if (overflow) {
-            p_copy[0] = INFINITY;
+            convert(float(INFINITY), p_copy[0]);
         }
     }
 }

+template <typename GRAD_T>
+__global__ void unpack_e5m2_kernel(
+        const uint8_t* p_in,
+        GRAD_T* p_out,
+        const size_t tsize)
+{
+    //Assuming 2D grids and 2D blocks
+    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
+    const int threadsPerBlock = blockDim.x * blockDim.y;
+    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
+    const int i = (blockId * threadsPerBlock + threadIdInBlock);
+    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
+
+    uint8_t pi[ILP];
+    GRAD_T po[ILP];
+
+    bool overflow = false;
+    for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
+#pragma unroll
+        for(int ii = 0; ii < ILP; ii++) {
+            pi[ii] = 0;
+            int j = j_start + i + totThreads*ii;
+            if (j < tsize) {
+                pi[ii] = p_in[j];
+            }
+        }
+
+#pragma unroll
+        for(int ii = 0; ii < ILP; ii++) {
+            convert(pi[ii], po[ii]);
+            if (!isfinite(po[ii])) {
+                overflow = true;
+            }
+        }
+
+#pragma unroll
+        for(int ii = 0; ii < ILP; ii++) {
+            int j = j_start + i + totThreads*ii;
+            if (j < tsize) {
+                p_out[j] = po[ii];
+            }
+        }
+    }
+
+    if (overflow) {
+        p_out[0] = INFINITY;
+    }
+}
+
 template <typename T, typename GRAD_T>
@@ -404,15 +532,15 @@ void fused_strided_check_finite(
     AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32");

     cudaStream_t stream = at::cuda::getCurrentCUDAStream();

     using namespace at; // prevents "toString is undefined" errors
     DISPATCH_FLOAT_AND_HALF(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
         strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
             noop.DATA_PTR<int>(),
             p_copy.DATA_PTR<scalar_t_0>(),
             tsize,
             stride,
             clear_overflow_first);
     );
     THCudaCheck(cudaGetLastError());
 }
@@ -432,69 +560,113 @@ void fused_adam_cuda(
         int bias_correction,
         float decay)
 {
     // using namespace at;

     //Get tensor size
     int tsize = p.numel();
     //Determine #threads and #blocks
     const int threadsPerBlock = 512;
     const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
     AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
     //Constants
     float step_size = 0;
     if (bias_correction == 1) {
         const float bias_correction1 = 1 - std::pow(beta1, step);
         const float bias_correction2 = 1 - std::pow(beta2, step);
         step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
     }
     else {
         step_size = lr;
     }
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();

     if (g.scalar_type() == at::ScalarType::Half) {
         //all other values should be fp32 for half gradients
         AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
         //dispatch is done on the gradient type
         using namespace at; // prevents "toString is undefined" errors
-        DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
-            using accscalar_t = at::acc_type<scalar_t_0, true>;
-            adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-                    p.DATA_PTR<accscalar_t>(),
-                    p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
-                    m.DATA_PTR<accscalar_t>(),
-                    v.DATA_PTR<accscalar_t>(),
-                    g.DATA_PTR<scalar_t_0>(),
-                    beta1,
-                    beta2,
-                    eps,
-                    grad_scale,
-                    step_size,
-                    tsize,
-                    (adamMode_t) mode,
-                    decay);
-        );
+        if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
+            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
+                using accscalar_t = at::acc_type<scalar_t_0, true>;
+                adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+                        p.DATA_PTR<accscalar_t>(),
+                        p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
+                        m.DATA_PTR<accscalar_t>(),
+                        v.DATA_PTR<accscalar_t>(),
+                        g.DATA_PTR<scalar_t_0>(),
+                        beta1,
+                        beta2,
+                        eps,
+                        grad_scale,
+                        step_size,
+                        tsize,
+                        (adamMode_t) mode,
+                        decay);
+            );
+        } else {
+            AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
+            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
+                using accscalar_t = at::acc_type<scalar_t_0, true>;
+                adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(
+                        p.DATA_PTR<accscalar_t>(),
+                        p_copy.DATA_PTR<uint8_t>(),
+                        m.DATA_PTR<accscalar_t>(),
+                        v.DATA_PTR<accscalar_t>(),
+                        g.DATA_PTR<scalar_t_0>(),
+                        beta1,
+                        beta2,
+                        eps,
+                        grad_scale,
+                        step_size,
+                        tsize,
+                        (adamMode_t) mode,
+                        decay);
+            );
+        }
     } else {
         using namespace at;
         DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
-            adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+            adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                     p.DATA_PTR<scalar_t_0>(),
                     NULL, //don't output p_copy for fp32, it's wasted write
                     m.DATA_PTR<scalar_t_0>(),
                     v.DATA_PTR<scalar_t_0>(),
                     g.DATA_PTR<scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
                     grad_scale,
                     step_size,
                     tsize,
                     (adamMode_t) mode,
                     decay);
         );
     }
     THCudaCheck(cudaGetLastError());
 }

+void unpack_e5m2_cuda(
+        at::Tensor & p_in,
+        at::Tensor & p_out)
+{
+    //Get tensor size
+    int tsize = p_in.numel();
+    AT_ASSERTM(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()");
+    //Determine #threads and #blocks
+    const int threadsPerBlock = 512;
+    const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
+    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
+    //Constants
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    AT_ASSERTM(p_in.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
+    AT_ASSERTM(p_out.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
+    DISPATCH_FLOAT_AND_HALF(p_out.scalar_type(), 0, "unpack_e5m2",
+        unpack_e5m2_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+            p_in.DATA_PTR<uint8_t>(),
+            p_out.DATA_PTR<scalar_t_0>(),
+            tsize);
+    );
+    THCudaCheck(cudaGetLastError());
+}
+
 void fused_adam_undo_cuda(
...
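The convert<>() specializations above implement the e5m2 packing scheme: an e5m2 value is simply the high byte of an IEEE fp16 (1 sign bit, 5 exponent bits, 2 mantissa bits). Packing masks the input down to its sign and exponent, divides that by 8 to obtain half an e5m2 ulp, adds it before the fp16 cast so that the subsequent byte truncation rounds to nearest rather than toward zero, and keeps as_byte[1], the high byte on a little-endian host. Unpacking re-inserts a zero low byte and reinterprets the 16-bit pattern as fp16. An illustrative numpy emulation, not part of the commit, with endian-independent shifts in place of the byte indexing:

import numpy as np

def pack_e5m2(x):
    """Emulate convert(float -> uint8): round to nearest e5m2 and keep fp16's high byte."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    # Sign+exponent of x divided by 8 is half an e5m2 ulp at x's magnitude.
    half_ulp = (x.view(np.uint32) & np.uint32(0xFF800000)).view(np.float32) / np.float32(8.0)
    # float32 -> float16 rounds to nearest; dropping the low byte truncates to 2 mantissa bits.
    return ((x + half_ulp).astype(np.float16).view(np.uint16) >> np.uint16(8)).astype(np.uint8)

def unpack_e5m2(b):
    """Emulate convert(uint8 -> half): re-insert a zero low byte and reinterpret as fp16."""
    b = np.ascontiguousarray(b, dtype=np.uint8)
    return (b.astype(np.uint16) << np.uint16(8)).view(np.float16)

roundtrip = unpack_e5m2(pack_e5m2(np.array([0.1, -3.0, 1e-4], dtype=np.float32)))  # e5m2 approximations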
@@ -46,7 +46,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                  dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4, predivide=True, internal_pipeline=False):
+                 dwu_num_chunks=4, predivide=True, internal_pipeline=False,
+                 e5m2_allgather=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
@@ -80,6 +81,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_chunks = dwu_num_chunks
         self._predivide = predivide
         self._internal_pipeline = internal_pipeline
+        self._e5m2_allgather = e5m2_allgather
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -306,7 +308,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
                 if self._full_pipeline:
                     if self._new_params is None:
-                        self._new_params = torch.zeros_like(self._flat_grads)
+                        if self._e5m2_allgather:
+                            self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8)
+                        else:
+                            self._new_params = torch.zeros_like(self._flat_grads)
                     self._pipeline_block(block_id, self._flat_grads, self._new_params)
                 else:
                     self._pipeline_block_reductions(block_id, self._flat_grads)
@@ -539,7 +544,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         if self._last_step or not self._overlap_reductions or not self._full_pipeline:
             if self._new_params is None:
-                self._new_params = torch.zeros_like(self._flat_grads)
+                if self._e5m2_allgather:
+                    self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8)
+                else:
+                    self._new_params = torch.zeros_like(self._flat_grads)
             for inv_block_id in range(self._num_blocks):
                 block_id = self._num_blocks - inv_block_id - 1
                 with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
@@ -551,7 +559,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # Check for overflow
         # Store state for loss scaler calculation
-        self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
+        if self._e5m2_allgather:
+            new_params = torch.empty_like(self._flat_grads)
+            fused_adam_cuda.unpack_e5m2(self._new_params, new_params)
+        else:
+            new_params = self._new_params
+        self.strided_check_finite(new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
         if self.peek_overflow:
             print("Reverting step")
             self.revert_step()
@@ -569,7 +582,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 state['step'] += 1
                 nels = p.numel()
                 offset = self._grads_info[param_i]['param_offset']
-                p.set_(self._new_params[offset:offset+nels].view_as(p))
+                p.set_(new_params[offset:offset+nels].view_as(p))
                 param_i += 1
         self._new_params = None
...
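With e5m2_allgather=False the optimizer behaves as before; with e5m2_allgather=True the gathered parameter buffer self._new_params is allocated as uint8, so the allgather moves one packed e5m2 byte per element, and step() unpacks it into a temporary via fused_adam_cuda.unpack_e5m2 before the finite check and the copy back into the model parameters. A construction sketch, not part of the commit; the import path, the lr argument, and the surrounding training setup are assumptions based on typical apex.contrib DistributedFusedAdam usage:

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam  # assumed import path

# torch.distributed is assumed to be initialized (NCCL backend, one process per GPU)
# before the optimizer is constructed; the remaining constructor arguments keep their defaults.
model = torch.nn.Linear(4096, 4096).cuda().half()
optimizer = DistributedFusedAdam(model.parameters(),
                                 lr=1e-3,
                                 e5m2_allgather=True)  # allgather packed e5m2 bytes, one byte per element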