Commit 3f86316e authored by Wil Kong's avatar Wil Kong
Browse files

Add FusedAdam with multi-tensor apply support.

parent 6644c6e6
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
// CUDA forward declaration // CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
...@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a ...@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &adam, "Adam optimized CUDA implementation."); m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
} }
...@@ -9,6 +9,10 @@ ...@@ -9,6 +9,10 @@
#include "ATen/Type.h" #include "ATen/Type.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h> #include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum{ typedef enum{
ADAM_MODE_0 =0, // eps under square root ADAM_MODE_0 =0, // eps under square root
...@@ -53,6 +57,93 @@ __global__ void adam_cuda_kernel( ...@@ -53,6 +57,93 @@ __global__ void adam_cuda_kernel(
} }
} }
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorList<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
};
void fused_adam_cuda( void fused_adam_cuda(
at::Tensor & p, at::Tensor & p,
at::Tensor & p_copy, at::Tensor & p_copy,
...@@ -129,3 +220,110 @@ void fused_adam_cuda( ...@@ -129,3 +220,110 @@ void fused_adam_cuda(
THCudaCheck(cudaGetLastError()); THCudaCheck(cudaGetLastError());
} }
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].type().scalarType() == at::ScalarType::Half) {
//alher values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
//dich is done on the gradient type
if (tl_sz == 5) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
}));
} else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
}));
}
} else {
if (tl_sz == 5) {
AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, scalar_t, scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
}));
} else {
AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, scalar_t, scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
}));
}
}
THCudaCheck(cudaGetLastError());
}
...@@ -2,6 +2,8 @@ import types ...@@ -2,6 +2,8 @@ import types
import torch import torch
import importlib import importlib
from ..multi_tensor_apply import multi_tensor_applier
class FusedAdam(torch.optim.Optimizer): class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
...@@ -25,6 +27,8 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -25,6 +27,8 @@ class FusedAdam(torch.optim.Optimizer):
adds eps to the bias-corrected second moment estimate before adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False) second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
.. _Adam\: A Method for Stochastic Optimization: .. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980 https://arxiv.org/abs/1412.6980
...@@ -35,10 +39,18 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -35,10 +39,18 @@ class FusedAdam(torch.optim.Optimizer):
def __init__(self, params, def __init__(self, params,
lr=1e-3, bias_correction = True, lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False, betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False): weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False):
global fused_adam_cuda global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda") fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._use_multi_tensor = False
if use_mt:
if not multi_tensor_applier.available:
print("Warning: multi_tensor_applier is unavailable")
else:
self._use_multi_tensor = True
self._overflow_buf = torch.cuda.IntTensor([0])
if amsgrad: if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.') raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction, defaults = dict(lr=lr, bias_correction=bias_correction,
...@@ -105,6 +117,12 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -105,6 +117,12 @@ class FusedAdam(torch.optim.Optimizer):
bias_correction = 1 if group['bias_correction'] else 0 bias_correction = 1 if group['bias_correction'] else 0
if self._use_multi_tensor:
if output_params:
tensorlists = [[],[],[],[],[]]
else:
tensorlists = [[],[],[],[]]
for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group): for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
#note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
if p.grad is None and grad is None: if p.grad is None and grad is None:
...@@ -130,18 +148,43 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -130,18 +148,43 @@ class FusedAdam(torch.optim.Optimizer):
state['step'] += 1 state['step'] += 1
out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
fused_adam_cuda.adam(p.data, if self._use_multi_tensor:
out_p, pl = [p.data, exp_avg, exp_avg_sq, grad]
exp_avg, if output_param is not None:
exp_avg_sq, pl.append(out_p)
grad,
group['lr'], for tl, t in zip(tensorlists, pl):
beta1, tl.append(t)
beta2, else:
group['eps'], fused_adam_cuda.adam(p.data,
combined_scale, out_p,
state['step'], exp_avg,
self.eps_mode, exp_avg_sq,
bias_correction, grad,
group['weight_decay']) group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
if self._use_multi_tensor:
multi_tensor_applier(
fused_adam_cuda.adam_mt,
self._overflow_buf,
tensorlists,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
return loss return loss
...@@ -56,6 +56,7 @@ if "--cuda_ext" in sys.argv: ...@@ -56,6 +56,7 @@ if "--cuda_ext" in sys.argv:
'--use_fast_math']})) '--use_fast_math']}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fused_adam_cuda', CUDAExtension(name='fused_adam_cuda',
include_dirs=['csrc'],
sources=['apex/optimizers/csrc/fused_adam_cuda.cpp', sources=['apex/optimizers/csrc/fused_adam_cuda.cpp',
'apex/optimizers/csrc/fused_adam_cuda_kernel.cu'], 'apex/optimizers/csrc/fused_adam_cuda_kernel.cu'],
extra_compile_args={'cxx': ['-O3',], extra_compile_args={'cxx': ['-O3',],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment