Unverified Commit 854b8890 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix max size in adam (#27)

* Update adam_kernel.cu

* Update adam_kernel.cu
parent 8e9e1c89
...@@ -8,10 +8,12 @@ ...@@ -8,10 +8,12 @@
#include "ATen/TensorUtils.h" #include "ATen/TensorUtils.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <limits>
#include <cstdint>
#include "type_shim.h" #include "type_shim.h"
template <typename T, typename GRAD_T> template <typename T, typename GRAD_T, typename SIZE_T>
__global__ void adam_cuda_kernel( __global__ void adam_cuda_kernel(
GRAD_T* __restrict__ p, GRAD_T* __restrict__ p,
T* __restrict__ m, T* __restrict__ m,
...@@ -22,17 +24,17 @@ __global__ void adam_cuda_kernel( ...@@ -22,17 +24,17 @@ __global__ void adam_cuda_kernel(
const float eps, const float eps,
const float grad_scale, const float grad_scale,
const float step_size, const float step_size,
const size_t tsize, const SIZE_T tsize,
const float decay_size) const float decay_size)
{ {
//Assuming 2D grids and 2D blocks //Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x; const SIZE_T blockId = static_cast<SIZE_T>(gridDim.x) * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y; const SIZE_T threadsPerBlock = static_cast<SIZE_T>(blockDim.x) * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; const SIZE_T threadIdInBlock = static_cast<SIZE_T>(threadIdx.y) * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock); const SIZE_T i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; const SIZE_T totThreads = gridDim.x*gridDim.y*threadsPerBlock;
for (int j = i; j < tsize; j+=totThreads) { for (SIZE_T j = i; j < tsize; j+=totThreads) {
// weight decay // weight decay
T cur_p = (T)p[j] * decay_size; T cur_p = (T)p[j] * decay_size;
T scaled_grad = static_cast<T>(g[j]) / grad_scale; T scaled_grad = static_cast<T>(g[j]) / grad_scale;
...@@ -58,11 +60,11 @@ void fused_adam_cuda( ...@@ -58,11 +60,11 @@ void fused_adam_cuda(
float decay) float decay)
{ {
//Get tensor size //Get tensor size
int tsize = p.numel(); size_t tsize = p.numel();
//Determine #threads and #blocks //Determine #threads and #blocks
const int threadsPerBlock = 512; const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); //AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants //Constants
float step_size = lr; float step_size = lr;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -79,9 +81,26 @@ void fused_adam_cuda( ...@@ -79,9 +81,26 @@ void fused_adam_cuda(
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) { if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
AT_ASSERTM(p.scalar_type() == g.scalar_type(), "expected parameter to be the same type as grad"); AT_ASSERTM(p.scalar_type() == g.scalar_type(), "expected parameter to be the same type as grad");
using namespace at; // prevents "toString is undefined" errors using namespace at; // prevents "toString is undefined" errors
if (tsize < std::numeric_limits<int32_t>::max()) {
DISPATCH_FLOAT_AND_HALF_AND_BF16(g.scalar_type(), 0, "adam_cuda_kernel", DISPATCH_FLOAT_AND_HALF_AND_BF16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>( adam_cuda_kernel<accscalar_t, scalar_t_0, int32_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(),
m.data_ptr<accscalar_t>(),
v.data_ptr<accscalar_t>(),
g.data_ptr<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
static_cast<int32_t>(tsize),
decay_size);
);
} else {
DISPATCH_FLOAT_AND_HALF_AND_BF16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0, size_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(), p.data_ptr<scalar_t_0>(),
m.data_ptr<accscalar_t>(), m.data_ptr<accscalar_t>(),
v.data_ptr<accscalar_t>(), v.data_ptr<accscalar_t>(),
...@@ -94,10 +113,27 @@ void fused_adam_cuda( ...@@ -94,10 +113,27 @@ void fused_adam_cuda(
tsize, tsize,
decay_size); decay_size);
); );
}
} else { } else {
using namespace at; using namespace at;
if (tsize < std::numeric_limits<int32_t>::max()) {
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>( adam_cuda_kernel<scalar_t_0, scalar_t_0, int32_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(),
m.data_ptr<scalar_t_0>(),
v.data_ptr<scalar_t_0>(),
g.data_ptr<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
static_cast<int32_t>(tsize),
decay_size);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0, size_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(), p.data_ptr<scalar_t_0>(),
m.data_ptr<scalar_t_0>(), m.data_ptr<scalar_t_0>(),
v.data_ptr<scalar_t_0>(), v.data_ptr<scalar_t_0>(),
...@@ -111,5 +147,6 @@ void fused_adam_cuda( ...@@ -111,5 +147,6 @@ void fused_adam_cuda(
decay_size); decay_size);
); );
} }
}
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment