Unverified Commit 81a2cf04 authored by Jun Ru Anderson, committed by GitHub

[feat] remove support for non-multitensor Adam


Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 57079b08
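
This commit removes the single-tensor code path from the fused Adam CUDA extension: the element-wise adam_cuda_kernel, its fused_adam_cuda launcher, and the single-tensor adam C++ entry point are deleted; the multi-tensor fused_adam_cuda_mt is renamed to fused_adam_cuda and is now bound as adam. On the Python side, the use_mt constructor flag is dropped, so step() always batches parameters into per-device tensor lists and makes one chunked multi-tensor call per device.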
 #include <torch/extension.h>
 
 // CUDA forward declaration
-void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
 
-#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-// C++ interface
-void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
-        CHECK_INPUT(p);
-        if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
-        CHECK_INPUT(m);
-        CHECK_INPUT(v);
-        CHECK_INPUT(g);
-
-        int64_t num_elem = p.numel();
-        AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
-        AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
-        AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
-        AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
-
-        fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
-}
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-        m.def("adam", &adam, "Adam optimized CUDA implementation.");
-        m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
+        m.def("adam", &fused_adam_cuda, "Multi tensor Adam optimized CUDA implementation.");
 }
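
The practical effect of the rebinding: Python code that previously called fused_adam_cuda.adam with individual tensors must now pass lists of tensors. A minimal sketch of a direct call against the new binding, assuming the compiled extension is importable as fused_adam_cuda (the module name comes from the build setup, not this diff) and using the tensor-list order [params, exp_avgs, exp_avg_sqs, grads] that the optimizer code below builds:

# Sketch of a direct call to the new multi-tensor binding; in practice the
# FusedAdam optimizer below assembles these lists and makes this call.
import torch
import fused_adam_cuda  # compiled extension; module name assumed

params = [torch.zeros(1024, device="cuda") for _ in range(2)]
exp_avgs = [torch.zeros_like(p) for p in params]     # m
exp_avg_sqs = [torch.zeros_like(p) for p in params]  # v
grads = [torch.ones_like(p) for p in params]         # g

overflow_buf = torch.cuda.IntTensor([0])  # noop_flag: update is skipped when nonzero

fused_adam_cuda.adam(
    2048 * 32,                               # chunk_size, as used by step() below
    overflow_buf,
    [params, exp_avgs, exp_avg_sqs, grads],  # tensor_lists: p, m, v, g
    1e-3,                                    # lr
    0.9,                                     # beta1
    0.999,                                   # beta2
    1e-8,                                    # eps
    1.0,                                     # grad_scale
    1,                                       # step
    0,                                       # mode 0: eps inside the square root
    1,                                       # bias_correction
    0.0,                                     # weight decay
)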
@@ -21,43 +21,7 @@ typedef enum{
     ADAM_MODE_1 =1 // eps outside square root
 } adamMode_t;
-template <typename T, typename GRAD_T>
-__global__ void adam_cuda_kernel(
-        GRAD_T* __restrict__ p,
-        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
-        T* __restrict__ m,
-        T* __restrict__ v,
-        const GRAD_T * __restrict__ g,
-        const float b1,
-        const float b2,
-        const float eps,
-        const float grad_scale,
-        const float step_size,
-        const size_t tsize,
-        adamMode_t mode,
-        const float decay)
-{
-        // Assuming 2D grids and 2D blocks
-        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
-        const int threadsPerBlock = blockDim.x * blockDim.y;
-        const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
-        const int i = (blockId * threadsPerBlock + threadIdInBlock);
-        const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
-
-        for (int j = i; j < tsize; j+=totThreads) {
-                T scaled_grad = g[j]/grad_scale;
-                m[j] = b1*m[j] + (1-b1)*scaled_grad;
-                v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
-                float denom;
-                if (mode == ADAM_MODE_0)
-                        denom = sqrtf(v[j] + eps);
-                else // Mode 1
-                        denom = sqrtf(v[j]) + eps;
-                float update = (m[j]/denom) + (decay*p[j]);
-                p[j] = (GRAD_T)((float)p[j] - (step_size*update));
-                if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
-        }
-}
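
In math terms, the deleted kernel applied the standard Adam update element-wise over a grid-stride loop, unscaling the gradient first:

\[
\hat g_j = \frac{g_j}{\text{grad\_scale}}, \qquad
m_j \leftarrow \beta_1 m_j + (1 - \beta_1)\,\hat g_j, \qquad
v_j \leftarrow \beta_2 v_j + (1 - \beta_2)\,\hat g_j^{\,2},
\]
\[
p_j \leftarrow p_j - \text{step\_size}\left(\frac{m_j}{\sqrt{v_j + \epsilon}} + \text{decay}\cdot p_j\right).
\]

Mode 1 uses \(\sqrt{v_j} + \epsilon\) in the denominator instead, and for mixed precision the updated \(p_j\) is mirrored into p_copy.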
 template <int DEPTH, typename T, typename GRAD_T>
 struct AdamFunctor
@@ -147,87 +111,6 @@ struct AdamFunctor
 };
-void fused_adam_cuda(
-        at::Tensor & p,
-        at::Tensor & p_copy,
-        at::Tensor & m,
-        at::Tensor & v,
-        at::Tensor & g,
-        float lr,
-        float beta1,
-        float beta2,
-        float eps,
-        float grad_scale,
-        int step,
-        int mode,
-        int bias_correction,
-        float decay)
-{
-        // using namespace at;
-        // Get tensor size
-        int tsize = p.numel();
-        // Determine #threads and #blocks
-        const int threadsPerBlock = 512;
-        const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
-        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
-        // Constants
-        float step_size = 0;
-        if (bias_correction == 1) {
-                const float bias_correction1 = 1 - std::pow(beta1, step);
-                const float bias_correction2 = 1 - std::pow(beta2, step);
-                step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
-        }
-        else {
-                step_size = lr;
-        }
-        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-        if (g.scalar_type() == at::ScalarType::Half) {
-                // all other values should be fp32 for half gradients
-                AT_ASSERTM(p.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
-                // dispatch is done on the gradient type
-                using namespace at; // prevents "toString is undefined" errors
-                DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
-                        using accscalar_t = at::acc_type<scalar_t_0, true>;
-                        adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-                                p.DATA_PTR<scalar_t_0>(),
-                                p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
-                                m.DATA_PTR<accscalar_t>(),
-                                v.DATA_PTR<accscalar_t>(),
-                                g.DATA_PTR<scalar_t_0>(),
-                                beta1,
-                                beta2,
-                                eps,
-                                grad_scale,
-                                step_size,
-                                tsize,
-                                (adamMode_t) mode,
-                                decay);
-                );
-        } else {
-                using namespace at;
-                DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
-                        adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-                                p.DATA_PTR<scalar_t_0>(),
-                                NULL, // don't output p_copy for fp32, it's a wasted write
-                                m.DATA_PTR<scalar_t_0>(),
-                                v.DATA_PTR<scalar_t_0>(),
-                                g.DATA_PTR<scalar_t_0>(),
-                                beta1,
-                                beta2,
-                                eps,
-                                grad_scale,
-                                step_size,
-                                tsize,
-                                (adamMode_t) mode,
-                                decay);
-                );
-        }
-        THCudaCheck(cudaGetLastError());
-}
-
-void fused_adam_cuda_mt(
+void fused_adam_cuda(
         int chunk_size,
         at::Tensor noop_flag,
         std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
...
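
The deleted launcher above also shows where Adam's bias correction is folded into a single step size before the kernel launch; with t = step it computed

\[
\text{step\_size} =
\begin{cases}
\mathrm{lr} \cdot \dfrac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}, & \text{bias\_correction} = 1,\\[6pt]
\mathrm{lr}, & \text{otherwise.}
\end{cases}
\]

The surviving multi-tensor launcher presumably computes the same quantity, but its body is collapsed in this diff.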
@@ -55,13 +55,10 @@ try:
             weight_decay: Optional[float] = 0.0,
             max_grad_norm: Optional[float] = 0.0,
             amsgrad: Optional[bool] = False,
-            use_mt: Optional[bool] = True,
         ):
             self._use_multi_tensor = False
-            if use_mt:
-                self._use_multi_tensor = True
-                self._overflow_buf = torch.cuda.IntTensor([0])  # type: ignore
+            self._overflow_buf = torch.cuda.IntTensor([0])  # type: ignore
 
             if amsgrad:
                 raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
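
The second hunk rewrites step(): every parameter now lands in a per-device tensor list, and one chunked multi-tensor call per device replaces the per-parameter launches. A rough Python model of the chunk schedule implied by the chunk_size argument (illustrative only; the actual slicing happens inside the CUDA applier):

def chunk_schedule(tensor_lists, chunk_size=2048 * 32):
    """Model how a multi-tensor applier visits [p, m, v, g] lists in fixed-size chunks.

    Returns (tensor_index, start, end) work items; the real applier packs many
    such items into a single kernel launch, which is where the speedup comes from.
    """
    schedule = []
    for idx, p in enumerate(tensor_lists[0]):  # all lists are index-aligned with params
        n = p.numel()
        for start in range(0, n, chunk_size):
            schedule.append((idx, start, min(start + chunk_size, n)))
    return schedule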
@@ -131,51 +128,30 @@ try:
                     state["step"] += 1
                     out_p = torch.tensor([])
 
-                    if self._use_multi_tensor:
-                        pl = [p.data, exp_avg, exp_avg_sq, grad]
-
-                        if p.device not in tensorlists:
-                            tensorlists[p.device] = [[], [], [], []]
-
-                        for tl, t in zip(tensorlists[p.device], pl):
-                            tl.append(t)
-
-                    else:
-                        with torch.cuda.device(p.device):
-                            fused_adam_cuda.adam(
-                                p.data,
-                                out_p,
-                                exp_avg,
-                                exp_avg_sq,
-                                grad,
-                                group["lr"],
-                                beta1,
-                                beta2,
-                                group["eps"],
-                                scale,
-                                state["step"],
-                                self.eps_mode,
-                                bias_correction,
-                                group["weight_decay"],
-                            )
+                    pl = [p.data, exp_avg, exp_avg_sq, grad]
+                    if p.device not in tensorlists:
+                        tensorlists[p.device] = [[], [], [], []]
+
+                    for tl, t in zip(tensorlists[p.device], pl):
+                        tl.append(t)
 
-            if self._use_multi_tensor:
-                for tensordevice, tensorlist in tensorlists.items():
-                    with torch.cuda.device(tensordevice):
-                        fused_adam_cuda.adam_mt(
-                            2048 * 32,
-                            self._overflow_buf,
-                            tensorlist,
-                            group["lr"],
-                            beta1,
-                            beta2,
-                            group["eps"],
-                            scale,
-                            state["step"],
-                            self.eps_mode,
-                            bias_correction,
-                            group["weight_decay"],
-                        )
+            for tensordevice, tensorlist in tensorlists.items():
+                with torch.cuda.device(tensordevice):
+                    fused_adam_cuda.adam(
+                        2048 * 32,
+                        self._overflow_buf,
+                        tensorlist,
+                        group["lr"],
+                        beta1,
+                        beta2,
+                        group["eps"],
+                        scale,
+                        state["step"],
+                        self.eps_mode,
+                        bias_correction,
+                        group["weight_decay"],
+                    )
 
             return loss
...
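
For callers of the optimizer itself, the only visible change is the removed constructor flag. A minimal before/after sketch, assuming the class patched above is exposed as FusedAdam (the import path is not shown in this diff):

import torch
# from <package>.adam import FusedAdam  # exact import path not shown in the diff

model = torch.nn.Linear(1024, 1024).cuda()

# Before this commit: opt = FusedAdam(model.parameters(), lr=1e-3, use_mt=True)
# After this commit, the multi-tensor path is the only path:
opt = FusedAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
opt.step()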