Unverified Commit 854b8890 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix max size in adam (#27)

* Update adam_kernel.cu

* Update adam_kernel.cu
parent 8e9e1c89
......@@ -8,10 +8,12 @@
#include "ATen/TensorUtils.h"
#include "ATen/AccumulateType.h"
#include <ATen/cuda/Exceptions.h>
#include <limits>
#include <cstdint>
#include "type_shim.h"
template <typename T, typename GRAD_T>
template <typename T, typename GRAD_T, typename SIZE_T>
__global__ void adam_cuda_kernel(
GRAD_T* __restrict__ p,
T* __restrict__ m,
......@@ -22,17 +24,17 @@ __global__ void adam_cuda_kernel(
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
const SIZE_T tsize,
const float decay_size)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
const SIZE_T blockId = static_cast<SIZE_T>(gridDim.x) * blockIdx.y + blockIdx.x;
const SIZE_T threadsPerBlock = static_cast<SIZE_T>(blockDim.x) * blockDim.y;
const SIZE_T threadIdInBlock = static_cast<SIZE_T>(threadIdx.y) * blockDim.x + threadIdx.x;
const SIZE_T i = (blockId * threadsPerBlock + threadIdInBlock);
const SIZE_T totThreads = gridDim.x*gridDim.y*threadsPerBlock;
for (int j = i; j < tsize; j+=totThreads) {
for (SIZE_T j = i; j < tsize; j+=totThreads) {
// weight decay
T cur_p = (T)p[j] * decay_size;
T scaled_grad = static_cast<T>(g[j]) / grad_scale;
......@@ -58,11 +60,11 @@ void fused_adam_cuda(
float decay)
{
//Get tensor size
int tsize = p.numel();
size_t tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = lr;
if (bias_correction == 1) {
......@@ -79,9 +81,26 @@ void fused_adam_cuda(
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
AT_ASSERTM(p.scalar_type() == g.scalar_type(), "expected parameter to be the same type as grad");
using namespace at; // prevents "toString is undefined" errors
if (tsize < std::numeric_limits<int32_t>::max()) {
DISPATCH_FLOAT_AND_HALF_AND_BF16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
adam_cuda_kernel<accscalar_t, scalar_t_0, int32_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(),
m.data_ptr<accscalar_t>(),
v.data_ptr<accscalar_t>(),
g.data_ptr<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
static_cast<int32_t>(tsize),
decay_size);
);
} else {
DISPATCH_FLOAT_AND_HALF_AND_BF16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0, size_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(),
m.data_ptr<accscalar_t>(),
v.data_ptr<accscalar_t>(),
......@@ -94,10 +113,27 @@ void fused_adam_cuda(
tsize,
decay_size);
);
}
} else {
using namespace at;
if (tsize < std::numeric_limits<int32_t>::max()) {
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
adam_cuda_kernel<scalar_t_0, scalar_t_0, int32_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(),
m.data_ptr<scalar_t_0>(),
v.data_ptr<scalar_t_0>(),
g.data_ptr<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
static_cast<int32_t>(tsize),
decay_size);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0, size_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(),
m.data_ptr<scalar_t_0>(),
v.data_ptr<scalar_t_0>(),
......@@ -111,5 +147,6 @@ void fused_adam_cuda(
decay_size);
);
}
}
AT_CUDA_CHECK(cudaGetLastError());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment