Commit c7b34549 authored by Thor Johnsen

Add no-flattening e5m2-allgather option

parent cd206434
@@ -10,6 +10,7 @@ void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::v
void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);
void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -64,7 +65,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
m.def("unpack_e5m2_mt", &unpack_e5m2_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
}
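
The unpack_e5m2 binding above points at a thin wrapper rather than at unpack_e5m2_cuda directly. That wrapper is not part of this diff; below is a minimal hypothetical sketch of what such a wrapper typically looks like (the name unpack_e5m2_sketch is illustrative), using the CHECK_CUDA and CHECK_CONTIGUOUS macros defined earlier to validate the tensors before forwarding to the CUDA implementation declared at the top of the file.

// Hypothetical sketch, not code from this commit: validate inputs, then call
// the CUDA implementation unpack_e5m2_cuda declared above.
void unpack_e5m2_sketch(at::Tensor & p_in, at::Tensor & p_out) {
    CHECK_CUDA(p_in);
    CHECK_CONTIGUOUS(p_in);
    CHECK_CUDA(p_out);
    CHECK_CONTIGUOUS(p_out);
    unpack_e5m2_cuda(p_in, p_out);
}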
@@ -120,7 +120,37 @@ __global__ void strided_check_finite_cuda_kernel(
}
for (int j = i; j < tsize; j+=totThreads) {
GRAD_T pi = p_copy[j];
if (!isfinite(pi)) {
*noop_gmem = 1;
}
}
}
template <>
__global__ void strided_check_finite_cuda_kernel(
volatile int* noop_gmem,
uint8_t* __restrict__ p_copy,
const size_t tsize,
int stride,
int clear_overflow_first)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;
if (clear_overflow_first) {
if (i == 0) {
*noop_gmem = 0;
}
__syncthreads();
}
for (int j = i; j < tsize; j+=totThreads) {
at::Half pi;
convert(p_copy[j], pi);
if (!isfinite(pi)) {
*noop_gmem = 1;
}
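
The convert() call above maps an e5m2 byte to an at::Half; its definition is outside this hunk. For orientation, here is a minimal hypothetical sketch of such a conversion (the helper name is illustrative, not the commit's convert()), assuming the usual convention that e5m2 keeps fp16's sign bit and 5-bit exponent and truncates the mantissa to 2 bits, so an e5m2 byte is simply the high byte of an fp16 bit pattern.

// Hypothetical sketch, not the convert() used by this commit.
__device__ __forceinline__ void e5m2_byte_to_half(uint8_t in, at::Half& out)
{
    // Place the byte in the high 8 bits of a half; the low 8 mantissa bits stay zero.
    out = at::Half((unsigned short)(in << 8), at::Half::from_bits());
}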
@@ -337,6 +367,65 @@ __global__ void adam_undo_cuda_kernel(
}
}
template <int DEPTH, typename FROM_T, typename TO_T>
struct UnpackE5M2Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl)
{
if (*noop_gmem != 0) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];
p_in += chunk_idx*chunk_size;
TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];
p_out += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
int dim = chunk_size < n ? chunk_size : n;
FROM_T pi[ILP];
TO_T po[ILP];
bool overflow = false;
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = FROM_T(0);
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
if (!isfinite(po[ii])) {
overflow = true;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
p_out[j] = po[ii];
}
}
}
if (overflow) {
*noop_gmem = 1;
}
}
};
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
@@ -533,7 +622,7 @@ void fused_strided_check_finite(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
noop.DATA_PTR<int>(),
p_copy.DATA_PTR<scalar_t_0>(),
@@ -669,6 +758,28 @@ void unpack_e5m2_cuda(
THCudaCheck(cudaGetLastError());
}
void unpack_e5m2_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out
{
//Constants
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2");
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 0, "unpack_e5m2_cuda_mt_kernel",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
UnpackE5M2Functor<2, uint8_t, scalar_t_0>());
);
THCudaCheck(cudaGetLastError());
}
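
A hypothetical host-side usage sketch of the new multi-tensor entry point (the helper name, the output allocation, and the chunk size are illustrative, not part of this commit): tensor_lists[0] carries the packed e5m2 byte tensors and tensor_lists[1] the fp16 outputs, matching the "p_in, p_out" comment above.

// Illustrative only: unpack a list of e5m2 byte tensors into freshly
// allocated half tensors with a single fused launch.
void unpack_all_e5m2(const std::vector<at::Tensor>& packed, at::Tensor noop_flag)
{
    std::vector<at::Tensor> unpacked;
    for (const auto& p : packed) {
        unpacked.push_back(at::empty(p.sizes(), p.options().dtype(at::kHalf)));
    }
    std::vector<std::vector<at::Tensor>> tensor_lists = {packed, unpacked};
    unpack_e5m2_cuda_mt(2048*32 /* chunk_size, illustrative */, noop_flag, tensor_lists);
}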
void fused_adam_undo_cuda(
at::Tensor & p,
at::Tensor & m,
......
import math
import torch
import importlib
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdam(torch.optim.Optimizer):
@@ -559,17 +560,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Check for overflow
# Store state for loss scaler calculation
if self._e5m2_allgather:
new_params = torch.empty_like(self._flat_grads)
fused_adam_cuda.unpack_e5m2(self._new_params, new_params)
else:
new_params = self._new_params
self.strided_check_finite(new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if self.peek_overflow:
print("Reverting step")
self.revert_step()
else:
# Copy self._new_params to model params
if self._e5m2_allgather:
p_in = []
p_out = []
with torch.no_grad():
param_i = 0
for group in self.param_groups:
@@ -582,8 +581,17 @@ class DistributedFusedAdam(torch.optim.Optimizer):
state['step'] += 1
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
p.set_(new_params[offset:offset+nels].view_as(p))
if self._e5m2_allgather:
p_in.append(self._new_params[offset:offset+nels].view_as(p))
p_out.append(p)
else:
p.set_(self._new_params[offset:offset+nels].view_as(p))
param_i += 1
if self._e5m2_allgather:
multi_tensor_applier(
fused_adam_cuda.unpack_e5m2_mt,
self._overflow_buf,
[p_in, p_out]);
self._new_params = None
torch.cuda.current_stream().wait_stream(self._blk_st[0])
......
@@ -34,6 +34,32 @@
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
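
A brief illustrative sketch of how the new macro is used (fill_one and the launch configuration are made up for this example, not part of the commit): the LEVEL argument is pasted into the alias name, so passing 0 exposes scalar_t_0 inside the variadic body, exactly as in the unpack_e5m2_cuda_mt dispatch above.

// Illustrative only: dispatch on the tensor's dtype (float, half, or byte)
// and launch a kernel templated on the matching C++ type.
template <typename T>
__global__ void fill_one(T* p, size_t n)
{
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) p[i] = T(1);
}

void fill_one_dispatch(at::Tensor & t)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    const int threads = 256;
    const int blocks = (int)((t.numel() + threads - 1) / threads);
    DISPATCH_FLOAT_HALF_AND_BYTE(t.scalar_type(), 0, "fill_one",
        fill_one<scalar_t_0><<<blocks, threads, 0, stream>>>(
            t.DATA_PTR<scalar_t_0>(), t.numel());
    );
}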
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
......