Commit cd206434 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Add e5m2 allgather option

parent aa90d31f
......@@ -9,6 +9,7 @@ void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Te
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
......@@ -50,11 +51,20 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f
fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void unpack_e5m2(at::Tensor & p_in, at::Tensor & p_out) {
CHECK_INPUT(p_in);
CHECK_INPUT(p_out);
int64_t num_elem = p_in.numel();
AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
unpack_e5m2_cuda(p_in, p_out);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
}
......@@ -21,6 +21,82 @@ typedef enum{
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
template <typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
vo = static_cast<TO_T>(vi);
}
template <>
__device__ void convert(const float vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = vi;
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, float& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = static_cast<float>(t.as_half);
}
template <>
__device__ void convert(const at::Half vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = static_cast<float>(vi);
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, at::Half& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = t.as_half;
}
template <typename GRAD_T>
__global__ void strided_check_finite_cuda_kernel(
volatile int* noop_gmem,
......@@ -44,17 +120,17 @@ __global__ void strided_check_finite_cuda_kernel(
}
for (int j = i; j < tsize; j+=totThreads) {
GRAD_T pi = p_copy[j];
GRAD_T pi = p_copy[j];
if (!isfinite(pi)) {
*noop_gmem = 1;
}
}
}
template <typename T, typename GRAD_T>
template <typename T, typename GRAD_T, typename REDU_T>
__global__ void adam_cuda_kernel(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
......@@ -122,7 +198,9 @@ __global__ void adam_cuda_kernel(
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
if (p_copy != NULL) {
convert(pi[ii], p_copy[j]);
}
}
}
}
......@@ -130,9 +208,59 @@ __global__ void adam_cuda_kernel(
if (p_copy != NULL) {
__syncthreads();
if (overflow) {
p_copy[0] = INFINITY;
convert(float(INFINITY), p_copy[0]);
}
}
}
template <typename GRAD_T>
__global__ void unpack_e5m2_kernel(
const uint8_t* p_in,
GRAD_T* p_out,
const size_t tsize)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
uint8_t pi[ILP];
GRAD_T po[ILP];
bool overflow = false;
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = 0;
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
if (!isfinite(po[ii])) {
overflow = true;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
p_out[j] = po[ii];
}
}
}
if (overflow) {
p_out[0] = INFINITY;
}
}
template <typename T, typename GRAD_T>
......@@ -404,15 +532,15 @@ void fused_strided_check_finite(
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
noop.DATA_PTR<int>(),
p_copy.DATA_PTR<scalar_t_0>(),
tsize,
stride,
clear_overflow_first);
);
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
noop.DATA_PTR<int>(),
p_copy.DATA_PTR<scalar_t_0>(),
tsize,
stride,
clear_overflow_first);
);
THCudaCheck(cudaGetLastError());
}
......@@ -432,69 +560,113 @@ void fused_adam_cuda(
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.DATA_PTR<uint8_t>(),
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
void unpack_e5m2_cuda(
at::Tensor & p_in,
at::Tensor & p_out)
{
//Get tensor size
int tsize = p_in.numel();
AT_ASSERTM(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()");
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
//Constants
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_ASSERTM(p_in.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
AT_ASSERTM(p_out.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
DISPATCH_FLOAT_AND_HALF(p_out.scalar_type(), 0, "unpack_e5m2",
unpack_e5m2_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p_in.DATA_PTR<uint8_t>(),
p_out.DATA_PTR<scalar_t_0>(),
tsize);
);
THCudaCheck(cudaGetLastError());
}
void fused_adam_undo_cuda(
......
......@@ -46,7 +46,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, internal_pipeline=False):
dwu_num_chunks=4, predivide=True, internal_pipeline=False,
e5m2_allgather=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
......@@ -80,6 +81,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._internal_pipeline = internal_pipeline
self._e5m2_allgather = e5m2_allgather
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
......@@ -306,7 +308,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
if self._full_pipeline:
if self._new_params is None:
self._new_params = torch.zeros_like(self._flat_grads)
if self._e5m2_allgather:
self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8)
else:
self._new_params = torch.zeros_like(self._flat_grads)
self._pipeline_block(block_id, self._flat_grads, self._new_params)
else:
self._pipeline_block_reductions(block_id, self._flat_grads)
......@@ -539,7 +544,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
if self._new_params is None:
self._new_params = torch.zeros_like(self._flat_grads)
if self._e5m2_allgather:
self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8)
else:
self._new_params = torch.zeros_like(self._flat_grads)
for inv_block_id in range(self._num_blocks):
block_id = self._num_blocks - inv_block_id - 1
with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
......@@ -551,7 +559,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Check for overflow
# Store state for loss scaler calculation
self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if self._e5m2_allgather:
new_params = torch.empty_like(self._flat_grads)
fused_adam_cuda.unpack_e5m2(self._new_params, new_params)
else:
new_params = self._new_params
self.strided_check_finite(new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if self.peek_overflow:
print("Reverting step")
self.revert_step()
......@@ -569,7 +582,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
state['step'] += 1
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
p.set_(self._new_params[offset:offset+nels].view_as(p))
p.set_(new_params[offset:offset+nels].view_as(p))
param_i += 1
self._new_params = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment