Unverified commit e1b7997a authored by Thor Johnsen, committed by GitHub

Merge pull request #753 from NVIDIA/revertable_fused_adam_with_mt_support

Reversible fused adam with mt support
parents cf50dc7c 758826fc
#include <torch/extension.h>
// CUDA forward declaration
void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// C++ interface
void strided_check_finite(
at::Tensor& overflow_flag,
at::Tensor& p_copy,
int stride,
int clear_overflow_first
) {
CHECK_INPUT(p_copy);
fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);
}
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
@@ -25,8 +40,47 @@ void adam(...)
fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void reversible_adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {
CHECK_INPUT(p_in);
CHECK_INPUT(p_out);
int64_t num_elem = p_in.numel();
AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
maybe_cast_cuda(overflow_flag, p_in, p_out);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
}
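// Python-side usage (illustrative sketch, based on the optimizer code below):
// the extension is built under the name "fused_adam_cuda" and loaded via
// importlib, e.g.
//   fused_adam_cuda = importlib.import_module("fused_adam_cuda")
//   fused_adam_cuda.reversible_adam(p, p_copy, m, v, g, lr, beta1, beta2, eps,
//                                   combined_scale, step, eps_mode,
//                                   bias_correction, weight_decay)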
@@ -384,3 +384,653 @@ void fused_adam_cuda_mt(...)
}
THCudaCheck(cudaGetLastError());
}
template <typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
vo = static_cast<TO_T>(vi);
}
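// The specializations below pack parameters into a 1-byte "e5m2"-style format:
// the stored byte is the high byte of the IEEE fp16 encoding (sign, 5 exponent
// bits, 2 mantissa bits). When packing, s.as_float is vi with its mantissa
// cleared (sign and exponent only), so adding s.as_float/8.0f contributes half
// of an e5m2 ulp at vi's exponent; dropping the low fp16 byte then rounds to
// nearest rather than truncating. Note: reading as_byte[1] as the high byte
// assumes a little-endian layout.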
template <>
__device__ void convert(const float vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = vi;
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, float& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = static_cast<float>(t.as_half);
}
template <>
__device__ void convert(const at::Half vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = static_cast<float>(vi);
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, at::Half& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = t.as_half;
}
template <typename GRAD_T>
__global__ void strided_check_finite_cuda_kernel(
volatile int* noop_gmem,
GRAD_T* __restrict__ p_copy,
const size_t tsize,
int stride,
int clear_overflow_first)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;
if (clear_overflow_first) {
if (i == 0) {
*noop_gmem = 0;
}
__syncthreads();
}
for (int j = i; j < tsize; j+=totThreads) {
GRAD_T pi = p_copy[j];
if (!isfinite(pi)) {
*noop_gmem = 1;
}
}
}
template <>
__global__ void strided_check_finite_cuda_kernel(
volatile int* noop_gmem,
uint8_t* __restrict__ p_copy,
const size_t tsize,
int stride,
int clear_overflow_first)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;
if (clear_overflow_first) {
if (i == 0) {
*noop_gmem = 0;
}
__syncthreads();
}
for (int j = i; j < tsize; j+=totThreads) {
at::Half pi;
convert(p_copy[j], pi);
if (!isfinite(pi)) {
*noop_gmem = 1;
}
}
}
template <typename FROM_T, typename TO_T>
__global__ void maybe_cast_kernel(
volatile int* overflow_flag,
const FROM_T* p_in,
TO_T* p_out,
const size_t tsize)
{
if (overflow_flag && *overflow_flag != 0) return;
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
FROM_T pi[ILP];
TO_T po[ILP];
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = 0;
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
p_out[j] = po[ii];
}
}
}
}
template <typename T, typename GRAD_T, typename REDU_T>
__global__ void reversible_adam_cuda_kernel(
T* __restrict__ p,
REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T mi[ILP];
T vi[ILP];
T pi[ILP];
T gi[ILP];
bool overflow = false;
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) {
mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vi[ii] + eps);
else // Mode 1
denom = sqrtf(vi[ii]) + eps;
float update = (mi[ii]/denom) + (decay*pi[ii]);
pi[ii] = pi[ii] - (step_size*update);
} else {
overflow = true;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
if (p_copy != NULL) {
convert(pi[ii], p_copy[j]);
}
}
}
}
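// If any lane saw a non-finite scaled gradient, poison p_copy[0] with Inf so
// that a subsequent finite check over p_copy (e.g. strided_check_finite)
// reports the overflow.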
if (p_copy != NULL) {
__syncthreads();
if (overflow) {
convert(float(INFINITY), p_copy[0]);
}
}
}
template <typename T, typename GRAD_T>
__global__ void maybe_adam_undo_cuda_kernel(
volatile int* overflow_flag,
T* __restrict__ p,
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
// NB! Skip undo kernel when overflow flag is NOT set
if (overflow_flag && *overflow_flag == 0) return;
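// Inverts one Adam step using the updated m and v still stored in memory:
//   p_old = (p_new + step_size * m_new/denom) / (1 - step_size*decay)
//   m_old = (m_new - (1 - b1)*g) / b1
//   v_old = (v_new - (1 - b2)*g*g) / b2
// where denom is recomputed from the stored (updated) v, matching the forward
// reversible_adam_cuda_kernel above.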
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T mi[ILP];
T vi[ILP];
T pi[ILP];
T gi[ILP];
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) {
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vi[ii] + eps);
else // Mode 1
denom = sqrtf(vi[ii]) + eps;
pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);
mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;
vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
// Make sure round off errors don't create (small) negative value.
// This can happen if we have to revert the very first step.
vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
}
}
}
}
template <int DEPTH, typename FROM_T, typename TO_T>
struct MaybeCastFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* overflow_flag,
TensorListMetadata<DEPTH>& tl)
{
if (overflow_flag && *overflow_flag != 0) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];
p_in += chunk_idx*chunk_size;
TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];
p_out += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
int dim = chunk_size < n ? chunk_size : n;
FROM_T pi[ILP];
TO_T po[ILP];
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = FROM_T(0);
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
p_out[j] = po[ii];
}
}
}
}
};
void fused_strided_check_finite(
at::Tensor & overflow_flag,
at::Tensor & p_copy,
int stride,
int clear_overflow_first)
{
//Get tensor size
int tsize = p_copy.numel();
int niter = (tsize + stride - 1) / stride;
//Determine #threads and #blocks
const int threadsPerBlock = 512;
//In order to avoid race condition, blocks must be 1 when clear_overflow_first flag is set.
const dim3 blocks(clear_overflow_first ? 1 : (niter+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.DATA_PTR<int>(),
p_copy.DATA_PTR<scalar_t_0>(),
tsize,
stride,
clear_overflow_first);
);
THCudaCheck(cudaGetLastError());
}
void fused_reversible_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.DATA_PTR<uint8_t>(),
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
reversible_adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
void maybe_cast_cuda(
at::Tensor & overflow_flag,
at::Tensor & p_in,
at::Tensor & p_out)
{
//Get tensor size
int tsize = p_in.numel();
AT_ASSERTM(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()");
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
//Constants
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda",
DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, "maybe_cast_cuda",
maybe_cast_kernel<scalar_t_0,scalar_t_1><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p_in.DATA_PTR<scalar_t_0>(),
p_out.DATA_PTR<scalar_t_1>(),
tsize); ))
THCudaCheck(cudaGetLastError());
}
void maybe_cast_cuda_mt(
int chunk_size,
at::Tensor overflow_flag,
std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out
{
//Constants
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2");
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel",
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); ))
THCudaCheck(cudaGetLastError());
}
void fused_maybe_adam_undo_cuda(
at::Tensor & overflow_flag,
at::Tensor & p,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p.DATA_PTR<accscalar_t>(),
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p.DATA_PTR<scalar_t_0>(),
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
overlap_reductions(boolean, optional): whether to overlap reductions
with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
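Example (illustrative sketch; assumes torch.distributed and the CUDA device
are already initialized and that loss scaling is handled by the surrounding
training loop)::

    optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3)
    optimizer.set_global_scale(loss_scale)
    scaled_loss.backward()
    optimizer.complete_reductions()
    optimizer.step()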
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdam does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
if self._overlap_reductions:
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size
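# Gradient buffer layout: total_param_size -> num_blocks blocks -> num_chunks
# chunks per block -> group_size shards per chunk, i.e.
#   block_size = num_chunks * chunk_size and chunk_size = group_size * shard_size.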
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
self._individual_flat_grads = []
for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
def __blockify(p):
return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
list_of_mega_shards = __shardify(p)
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._shard_size
return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
def __packed_chunkify(p):
# in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
# This paragraph does two things:
# 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather
self._packed_flat_to_model_params = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size
for p, grads_info in zip(self._model_params, self._grads_info):
flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
clipped_start = max(flat_grad_start, flat_shard_start)
clipped_end = min(flat_grad_end, flat_shard_end)
if clipped_start < clipped_end:
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
master_param_fragment.copy_(model_param_fragment)
p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out]
self._distributed_weight_update = distributed_weight_update # Is this still needed?
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
self._completion_st = torch.cuda.Stream()
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
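# Returns [start, end) of the next block (scanning from the highest offsets
# down) whose gradients have all been produced, or [] if no block is ready.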
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Optionally compute L2 grad norm
if self._compute_L2_grad_norm and block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
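# Fold optional gradient clipping into the scale the kernel divides by: the
# unscaled gradient norm is L2_grad_norm / global_scale, so when it exceeds
# max_grad_norm the divisor becomes global_scale / clip_coeff with
# clip_coeff = max_grad_norm / (unscaled_norm + 1e-6).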
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def _pipeline_block_step(self, block_id):
# Call step kernel once per block
ag_stream = self._ag_st[block_id%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p_blocks[block_id],
self._fp16_p_blocks[block_id],
self._fp32_m_blocks[block_id],
self._fp32_v_blocks[block_id],
self._fp16_g_blocks[block_id])
# Call all-gather once per step.
# FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
if block_id == 0:
for other_ag_stream in self._ag_st:
self._completion_st.wait_stream(other_ag_stream)
with torch.cuda.stream(self._completion_st):
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _pipeline_step(self):
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p,
self._fp16_p,
self._fp32_m,
self._fp32_v,
self._fp16_g)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step:
if self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
if self._full_pipeline:
self._pipeline_block_step(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def has_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Clears the overflow flag.
"""
has_overflow = self._has_overflow
self._has_overflow = False
return has_overflow
@property
def peek_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Does not clear overflow flag.
"""
return self._has_overflow
def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
"""Strided check for overflow.
You can get status by calling has_overflow.
"""
if start >= 0 and start < end:
out_p = output_params[start:end]
else:
out_p = output_params
fused_adam_cuda.strided_check_finite(self._overflow_buf,
out_p,
stride,
1 if clear else 0)
self._has_overflow = self._overflow_buf.item() != 0
return self._has_overflow
@property
def L2_grad_norm(self):
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
else:
return None
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
if not grad_generated:
grad_info = self._grads_info[param_i]
param_offset = grad_info["param_offset"]
param_size = grad_info["param_grads_size"]
self._flat_grads[param_offset:param_offset+param_size].zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_reductions(block_id)
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def revert_step(self):
"""Revert effect of previously calling partial_step.
"""
# Call undo kernel once per step
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p,
self._fp32_m,
self._fp32_v,
self._fp16_g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
self._pipeline_step()
with torch.cuda.stream(self._completion_st):
# Check for overflow
# Store state for loss scaler calculation
has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if has_overflow:
self.revert_step()
else:
# Copy self._new_params to model params
for p in self._model_params: self.state[p]['step'] += 1
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params)
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
return loss
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdamV2(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
overlap_reductions(boolean, optional): whether to overlap reductions
with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdam does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdamV2, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
if self._overlap_reductions:
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._block_size // self._group_size
self._chunk_size = self._shard_size // self._num_chunks
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d, self._chunk_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size,self._chunk_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._chunk_size
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
self._individual_flat_grads = []
for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_shards = [__shardify(block) for block in list_of_blocks]
list_of_list_of_list_of_chunks = [[__chunkify(shard) for shard in shards] for shards in list_of_list_of_shards]
return list_of_blocks, list_of_list_of_shards, list_of_list_of_list_of_chunks
self._flat_grads_blocks, self._flat_grads_shards, self._flat_grads_chunks = _flat_split(self._flat_grads)
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
def __blockify(p):
return [p[block_id*self._num_chunks*self._chunk_size:(block_id+1)*self._num_chunks*self._chunk_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_mega_shards = __shardify(p)
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._chunk_size
return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
def __packed_chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
# current arrangement
#
# self._flat_grads
# self._flat_grads_blocks [x self._num_blocks, self._block_size]
# self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
# self._flat_grads_shards [x self._group_size, self._shard_size]
#
# self._new_params
# self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._shard_size]
# self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
# self._new_params_mega_chunks [x self._num_chunks, self._shard_size]
#
# self._fp32_p
# self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
# self._fp32_p_chunks [x self._num_chunks, self._shard_size]
# each chunk contains one shard
# same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
#
# Usage:
#
# for chunk_id in range(self._num_chunks):
# works[chunk_id] = torch.distributed.reduce_scatter(self._flat_grads_chunks[block_id][chunk_id], self._fp16_g_chunks[block_id][chunk_id], ...)
#
# ----------------------------------------------------------------------------------------
#
# new arrangement
#
# NB! New equations for self._shard_size and self._chunk_size
#
# self._flat_grads
# self._flat_grads_blocks [x self._num_blocks, self._block_size]
# self._flat_grads_shards [x self._group_size, self._shard_size]
# self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
#
# self._new_params
# self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._chunk_size]
# self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
# self._new_params_mega_chunks [x self._num_chunks, self._chunk_size]
#
# self._fp32_p
# self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
# self._fp32_p_chunks [x self._num_chunks, self._chunk_size]
# same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
#
# Usage:
#
# work = torch.distributed.reduce_scatter(self._flat_grads_blocks[block_id], self._fp16_g[block_id], ...)
# for chunk_id in range(self._num_chunks):
# work.wait()
# works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], ...)
# or
# work.wait()
# works[0] = torch.distributed.all_reduce(self._fp16_g_blocks[block_id], ...)
#
# This paragraph does two things:
# 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather
self._packed_flat_to_model_params = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
flat_shard_start = (block_id * self._group_size + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size
for p, grads_info in zip(self._model_params, self._grads_info):
flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
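# Intersect the parameter's flat range [flat_grad_start, flat_grad_end) with the
# shard's flat range [flat_shard_start, flat_shard_end); only the overlapping
# fragment (if any) belongs to this shard. E.g. (assumed numbers) a parameter
# spanning [1000, 3000) and a shard spanning [2048, 4096) overlap in [2048, 3000),
# giving grad_offset=1048, grad_length=952 and shard_offset=0.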
clipped_start = max(flat_grad_start, flat_shard_start)
clipped_end = min(flat_grad_end, flat_shard_end)
if clipped_start < clipped_end:
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_blocks[shard_id][block_id][shard_offset:shard_offset+grad_length]
self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_blocks[block_id][shard_offset:shard_offset+grad_length]
print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
master_param_fragment.copy_(model_param_fragment)
p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out]
self._distributed_weight_update = distributed_weight_update # Is this still needed?
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
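# Process group layout (illustrative, assumed sizes): with world_size=8 and
# self._group_size=4 there are self._num_groups=2 groups. The reduce-scatter /
# all-gather groups built below span the ranks inside one group, e.g. [0,1,2,3]
# and [4,5,6,7], while the all-reduce groups connect the ranks that hold the
# same shard across groups, e.g. [0,4], [1,5], [2,6], [3,7].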
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
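# A dummy all_reduce is issued on each freshly created group (here and for the
# reduce-scatter / all-gather groups below), presumably to force eager
# initialization of the communicators before they are used on side streams.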
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
self._completion_st = torch.cuda.Stream()
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
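# Blocks are flushed from the highest index downwards. Backprop produces
# gradients in reverse parameter order, so a block is ready to flush once the
# contiguous tail of generated gradients reaches back to (or past) the block's
# start offset; the checks below implement exactly that.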
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
rs_stream = self._rs_st[block_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
rs_work = torch.distributed.reduce_scatter(self._fp16_g_blocks[block_id],self._flat_grads_shards[block_id],group=self._rs_pg[block_id%self._num_rs_pg],async_op=True,no_copy=True)
for chunk_id in range(self._num_chunks):
works[chunk_id] = rs_work
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
rs_work.wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Optionally compute L2 grad norm
if self._compute_L2_grad_norm and block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for reduced_block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[reduced_block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
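# Worked example (assumed values): with self._global_scale=65536,
# max_grad_norm=1.0 and a scaled L2_grad_norm of 131072 (unscaled norm 2.0),
# the first line gives 1.0/(2.0+1e-6) ~= 0.5 and the second 65536/0.5 = 131072,
# so the gradients are effectively divided by 131072, which both removes the
# loss scale and clips the update to max_grad_norm.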
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def _pipeline_block_step(self, block_id):
# Call step kernel once per block
ag_stream = self._ag_st[block_id%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p_blocks[block_id],
self._fp16_p_blocks[block_id],
self._fp32_m_blocks[block_id],
self._fp32_v_blocks[block_id],
self._fp16_g_blocks[block_id])
# Call all-gather once per step.
# FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
if block_id == 0:
for other_ag_stream in self._ag_st:
self._completion_st.wait_stream(other_ag_stream)
with torch.cuda.stream(self._completion_st):
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _pipeline_step(self):
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p,
self._fp16_p,
self._fp32_m,
self._fp32_v,
self._fp16_g)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step:
if self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
if self._full_pipeline:
self._pipeline_block_step(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def has_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Clears the overflow flag.
"""
has_overflow = self._has_overflow
self._has_overflow = False
return has_overflow
@property
def peek_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Does not clear overflow flag.
"""
return self._has_overflow
def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
"""Strided check for overflow.
You can get status by calling has_overflow.
"""
if start >= 0 and start < end:
out_p = output_params[start:end]
else:
out_p = output_params
fused_adam_cuda.strided_check_finite(self._overflow_buf,
out_p,
stride,
1 if clear else 0)
self._has_overflow = self._overflow_buf.item() != 0
return self._has_overflow
@property
def L2_grad_norm(self):
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
else:
return None
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
if not grad_generated:
grad_info = self._grads_info[param_i]
param_offset = grad_info["param_offset"]
param_size = grad_info["param_grads_size"]
self._flat_grads[param_offset:param_offset+param_size].zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_reductions(block_id)
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def revert_step(self):
"""Revert effect of previously calling partial_step.
"""
# Call undo kernel once per step
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p,
self._fp32_m,
self._fp32_v,
self._fp16_g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
self._pipeline_step()
with torch.cuda.stream(self._completion_st):
# Check for overflow
# Store state for loss scaler calculation
has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if has_overflow:
self.revert_step()
else:
# Copy self._new_params to model params
for p in self._model_params: self.state[p]['step'] += 1
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params)
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
return loss
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdamV3(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
overlap_reductions(boolean, optional): whether to overlap reductions
with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
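Example (illustrative sketch only; names and values are placeholders)::

    optimizer = DistributedFusedAdamV3(model.parameters(), lr=1e-3)
    ...
    optimizer.set_global_scale(loss_scale)
    scaled_loss.backward()
    optimizer.complete_reductions()
    optimizer.step()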
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdamV3 does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdamV3 does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdamV3, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Ways to revert a step:
# 1 -> undo kernel
# 2 -> double buffer fp32 parameters
# 3 -> undo kernel + double buffer (debug, print norm of difference)
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
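# Grab each parameter's gradient accumulator node from the autograd graph and
# register a hook on it, so that _do_overlapped_reduction fires as soon as this
# parameter's gradient has been produced during backward.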
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128-byte alignment (64 fp16 elements) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
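# e.g. (assumed numbers) a p_offset of 1000 after a non-consecutive parameter
# is rounded up to 1024, the next multiple of 64 fp16 elements (128 bytes).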
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._total_param_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._flat_params = torch.zeros_like(self._flat_grads)
def _flat_split(flat):
def __flat_blockify(flat):
return [flat[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __flat_shardify(flat):
return [flat[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
return __flat_blockify(flat), __flat_shardify(flat)
self._flat_grads_blocks, self._flat_grads_shards = _flat_split(self._flat_grads)
self._flat_params_blocks, self._flat_params_shards = _flat_split(self._flat_params)
# master params
self._fp32_p = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
# Copy model params into flat_params, then repoint each model param at its slice of flat_params via p.set_().
self._individual_flat_grads = []
with torch.no_grad():
for p, grads_info in zip(self._model_params, self._grads_info):
start = grads_info["param_offset"]
end = start + grads_info["param_grads_size"]
flat_p = self._flat_params[start:end].view_as(p)
flat_p.copy_(p)
p.set_(flat_p)
flat_grad = self._flat_grads[start:end]
self._individual_flat_grads.append(flat_grad)
self._fp32_p.copy_(self._flat_params_shards[self._rank_in_group].float())
self._dwu_st = torch.cuda.Stream()
self._l2_grad_norm_st = torch.cuda.Stream()
for group_i in range(self._num_groups):
ranks = [group_i*self._group_size+local_rank for local_rank in range(self._group_size)]
pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg = pg
torch.distributed.all_reduce(self._overflow_buf, group=self._ag_pg)
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
@property
def has_overflow(self):
return self.L2_grad_norm is not None and not math.isfinite(self.L2_grad_norm)
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step and self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._dwu_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._dwu_st):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
torch.distributed.all_reduce(self._flat_grads_blocks[block_id])
if block_id == 0:
self._l2_grad_norm_st.wait_stream(self._dwu_st)
with torch.cuda.stream(self._l2_grad_norm_st):
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def L2_grad_norm(self):
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, flat_grad in enumerate(self._individual_flat_grads):
if not self._grads_generated[param_i]:
flat_grad.zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
self._dwu_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._dwu_st):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
torch.distributed.all_reduce(self._flat_grads)
self._l2_grad_norm_st.wait_stream(self._dwu_st)
with torch.cuda.stream(self._l2_grad_norm_st):
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
with torch.cuda.stream(self._dwu_st):
self.__launch_step_kernel(
self._fp32_p,
self._flat_params_shards[self._rank_in_group],
self._fp32_m,
self._fp32_v,
self._flat_grads_shards[self._rank_in_group])
torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
for p in self._model_params: self.state[p]['step'] += 1
torch.cuda.current_stream().wait_stream(self._dwu_st)
return loss
......@@ -34,6 +34,32 @@
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
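// Illustrative use of the dispatch macro above (kernel and variable names are
// hypothetical):
//   DISPATCH_FLOAT_HALF_AND_BYTE(g_in.scalar_type(), 0, cast_kernel,
//       DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, cast_kernel,
//           launch_cast<scalar_t_0, scalar_t_1>(g_in, p_out)));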
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
......