"vscode:/vscode.git/clone" did not exist on "66f2922028754c6a7af648e8d39c62a7d7a49659"
Commit 1a994e37 authored by Thor Johnsen

First commit

parent 80b90b9d
@@ -2,8 +2,10 @@
// CUDA forward declarations
void fused_strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first);
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
@@ -11,6 +13,15 @@ void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::v
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// C++ interface
void strided_check_finite(
at::Tensor& noop,
at::Tensor& p_copy,
int stride,
int clear_overflow_first
) {
CHECK_INPUT(p_copy);
fused_strided_check_finite(noop, p_copy, stride, clear_overflow_first);
}
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
@@ -25,8 +36,23 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
}
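For orientation, a minimal host-side sketch of how the single-tensor entry point above might be driven from C++. This is illustrative only: it assumes a libtorch build with CUDA and that this translation unit is linked in; the tensor shape, hyperparameters, and the helper name example_step are placeholders, not part of this commit.

#include <torch/torch.h>

// Hypothetical driver: one fused Adam step on a single fp32 parameter tensor.
void example_step() {
    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto p = torch::zeros({1024}, opts);     // parameters
    auto m = torch::zeros_like(p);           // exp_avg
    auto v = torch::zeros_like(p);           // exp_avg_sq
    auto g = torch::randn_like(p);           // gradients
    auto p_copy = torch::empty({0}, opts);   // empty => skip the fp16 master-copy path
    // adam(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay)
    adam(p, p_copy, m, v, g, 1e-3f, 0.9f, 0.999f, 1e-8f, 1.0f,
         /*step=*/1, /*mode=*/0, /*bias_correction=*/1, /*decay=*/0.0f);
}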
@@ -21,6 +21,36 @@ typedef enum{
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
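As a restatement of what the kernels below compute (not text from the commit): with the scaled gradient \(\hat g = g/\mathrm{grad\_scale}\), the two modes differ only in the denominator, and the bias-corrected step size matches the host-side computation further down.

\[
m \leftarrow \beta_1 m + (1-\beta_1)\hat g, \qquad
v \leftarrow \beta_2 v + (1-\beta_2)\hat g^2,
\]
\[
d = \begin{cases}\sqrt{v+\epsilon} & \text{mode 0}\\[2pt] \sqrt{v}+\epsilon & \text{mode 1}\end{cases}, \qquad
p \leftarrow p - s\left(\frac{m}{d} + \mathrm{decay}\cdot p\right), \qquad
s = \mathrm{lr}\,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}}.
\]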
template <typename GRAD_T>
__global__ void strided_check_finite_cuda_kernel(
volatile int* noop_gmem,
GRAD_T* __restrict__ p_copy,
const size_t tsize,
int stride,
int clear_overflow_first)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;
if (clear_overflow_first) {
if (i == 0) {
*noop_gmem = 0;
}
__syncthreads();
}
for (int j = i; j < tsize; j+=totThreads) {
GRAD_T pi = p_copy[j];
if (!isfinite(pi)) {
*noop_gmem = 1;
}
}
}
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
T* __restrict__ p,
@@ -37,26 +67,148 @@ __global__ void adam_cuda_kernel(
adamMode_t mode,
const float decay)
{
    //Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
    T mi[ILP];
    T vi[ILP];
    T pi[ILP];
    T gi[ILP];
    bool overflow = false;
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++) {
            T scaled_grad = gi[ii]/grad_scale;
            if (isfinite(scaled_grad)) {
                mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
                vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;
                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(vi[ii] + eps);
                else // Mode 1
                    denom = sqrtf(vi[ii]) + eps;
                float update = (mi[ii]/denom) + (decay*pi[ii]);
                pi[ii] = pi[ii] - (step_size*update);
            } else {
                overflow = true;
            }
        }
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
}
}
}
if (p_copy != NULL) {
__syncthreads();
if (overflow) {
p_copy[0] = INFINITY;
}
}
}
template <typename T, typename GRAD_T>
__global__ void adam_undo_cuda_kernel(
T* __restrict__ p,
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T mi[ILP];
T vi[ILP];
T pi[ILP];
T gi[ILP];
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
            int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) {
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vi[ii] + eps);
else // Mode 1
denom = sqrtf(vi[ii]) + eps;
pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);
mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;
vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
// Make sure round off errors don't create (small) negative value.
// This can happen if we have to revert the very first step.
                vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
            int j = j_start + i + totThreads*ii;
if (j < tsize) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
}
}
}
}
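The undo kernel inverts the step algebraically (again a restatement, not commit text): from the forward update \(p' = (1 - s\cdot\mathrm{decay})\,p - s\,m/d\) it recovers

\[
p = \frac{p' + s\,m/d}{1 - s\cdot\mathrm{decay}}, \qquad
m_{\mathrm{prev}} = \frac{m - (1-\beta_1)\hat g}{\beta_1}, \qquad
v_{\mathrm{prev}} = \max\!\left(\frac{v - (1-\beta_2)\hat g^2}{\beta_2},\,0\right),
\]

where the clamp guards against round-off producing a small negative value when reverting the very first step.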
template <int DEPTH, typename T, typename GRAD_T>
@@ -93,59 +245,181 @@ struct AdamFunctor
}
n -= chunk_idx*chunk_size;
int dim = chunk_size < n ? chunk_size : n;
        T mi[ILP];
        T vi[ILP];
        T pi[ILP];
        T gi[ILP];
        bool overflow = false;
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
                int j = j_start + threadIdx.x + ii*blockDim.x;
                if (j < dim) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x;
T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) {
mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vi[ii] + eps);
else // Mode 1
denom = sqrtf(vi[ii]) + eps;
float update = (mi[ii]/denom) + (decay*pi[ii]);
pi[ii] = pi[ii] - (step_size*update);
} else {
                    overflow = true;
}
}
            #pragma unroll
            for(int ii = 0; ii < ILP; ii++) {
                int j = j_start + threadIdx.x + ii*blockDim.x;
                if (j < dim) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
}
}
}
if (overflow) {
*noop_gmem = 1;
}
}
};
template <int DEPTH, typename T, typename GRAD_T>
struct AdamUndoFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
int dim = chunk_size < n ? chunk_size : n;
T mi[ILP];
T vi[ILP];
T pi[ILP];
T gi[ILP];
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
                int j = j_start + threadIdx.x + ii*blockDim.x;
                if (j < dim) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x;
T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) {
                    float denom;
                    if (mode == ADAM_MODE_0)
                        denom = sqrtf(vi[ii] + eps);
                    else // Mode 1
                        denom = sqrtf(vi[ii]) + eps;
                    pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);
                    mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;
                    vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
                    // Make sure round off errors don't create a (small) negative value.
                    // This can happen if we have to revert the very first step.
                    vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x;
                if (j < dim) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
}
}
}
}
};
void fused_strided_check_finite(
at::Tensor & noop,
at::Tensor & p_copy,
int stride,
int clear_overflow_first)
{
//Get tensor size
int tsize = p_copy.numel();
int niter = (tsize + stride - 1) / stride;
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((niter+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
noop.DATA_PTR<int>(),
p_copy.DATA_PTR<scalar_t_0>(),
tsize,
stride,
clear_overflow_first);
);
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
@@ -227,6 +501,84 @@ void fused_adam_cuda(
}
void fused_adam_undo_cuda(
at::Tensor & p,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
        DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_undo_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
        DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_undo_cuda_kernel",
adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
@@ -262,48 +614,118 @@ void fused_adam_cuda_mt(
        //dispatch is done on the gradient type
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
} else {
if (tl_sz == 5) {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
}
THCudaCheck(cudaGetLastError());
}
void fused_adam_undo_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4, "expected tensor list of size 4");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
        //all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
        //dispatch is done on the gradient type
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
@@ -311,15 +733,15 @@ void fused_adam_cuda_mt(
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
@@ -327,8 +749,7 @@ void fused_adam_cuda_mt(
step_size,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
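To make the multi-tensor entry points concrete, a hedged sketch of the tensor_lists layout they expect. The list order p, m, v, g, optional p_copy follows the signature comment above; the tensor names, hyperparameter variables, and the chunk size are placeholders, not values from this commit.

// Illustrative only: each inner vector holds one tensor per parameter,
// in matching order across all lists.
std::vector<std::vector<at::Tensor>> tensor_lists = {
    {p0, p1},    // parameters (fp32)
    {m0, m1},    // exp_avg
    {v0, v1},    // exp_avg_sq
    {g0, g1},    // gradients (fp16 or fp32)
    // {h0, h1}, // optional fp16 master copies; with them, tl_sz == 5
};
auto noop_flag = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
fused_adam_cuda_mt(/*chunk_size=*/2048*32, noop_flag, tensor_lists,
                   lr, beta1, beta2, eps, grad_scale, step,
                   /*mode=*/0, /*bias_correction=*/1, decay);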