#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
#include "ATen/AccumulateType.h"
#include <assert.h>
#include <limits>
#include <cstdint>
#include "type_shim.h"

// Fused Adam update kernel (weight decay folded into the parameter read).
//
// Template parameters:
//   T      - accumulation type for the m/v moment buffers (float or double)
//   GRAD_T - storage type of the parameters and gradients (may be half/bf16)
//   SIZE_T - index type; int32_t when the tensor fits, 64-bit otherwise
//
// Launch expectations: up to 2D grid and 2D block; a grid-stride loop covers
// `tsize` elements, so any launch configuration is correct.
template <typename T, typename GRAD_T, typename SIZE_T>
__global__ void adam_cuda_kernel(
        GRAD_T* __restrict__ p,          // parameters, updated in place
        T* __restrict__ m,               // first-moment (exp. avg of grad)
        T* __restrict__ v,               // second-moment (exp. avg of grad^2)
        const GRAD_T* __restrict__ g,    // incoming (possibly scaled) gradients
        const float b1,                  // Adam beta1
        const float b2,                  // Adam beta2
        const float eps,                 // denominator epsilon
        const float grad_scale,          // loss-scale divisor applied to g
        const float step_size,           // lr with bias correction pre-folded
        const SIZE_T tsize,              // total number of elements
        const float decay_size)          // 1 - step_size * weight_decay
{
    // Flatten a (2D grid, 2D block) launch into one linear thread id.
    // Widen to SIZE_T before multiplying so the index math cannot overflow
    // 32-bit intermediates on very large launches.
    const SIZE_T blockId = static_cast<SIZE_T>(gridDim.x) * blockIdx.y + blockIdx.x;
    const SIZE_T threadsPerBlock = static_cast<SIZE_T>(blockDim.x) * blockDim.y;
    const SIZE_T threadIdInBlock = static_cast<SIZE_T>(threadIdx.y) * blockDim.x + threadIdx.x;
    const SIZE_T i = blockId * threadsPerBlock + threadIdInBlock;
    const SIZE_T totThreads = static_cast<SIZE_T>(gridDim.x) * gridDim.y * threadsPerBlock;

    // Grid-stride loop: correctness does not depend on the launch geometry.
    for (SIZE_T j = i; j < tsize; j += totThreads) {
        // Decoupled weight decay applied to the parameter before the step.
        T cur_p = static_cast<T>(p[j]) * decay_size;
        // Undo the loss-scaling applied to the gradient before accumulation.
        T scaled_grad = static_cast<T>(g[j]) / grad_scale;
        m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
        v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
        const float update = m[j] / (sqrtf(v[j]) + eps);
        p[j] = static_cast<GRAD_T>(cur_p - step_size * update);
    }
}

// Host-side launcher for the fused Adam step.
//
// p, m, v, g : parameter, moment and gradient tensors (same numel).
//              When g is half/bf16, p must match g's dtype and m/v are the
//              float accumulation type; otherwise all share g's float/double
//              dtype.
// lr         : learning rate.
// beta1/2    : Adam moment decay rates.
// eps        : denominator epsilon.
// grad_scale : loss-scale divisor (1.0 when no loss scaling is used).
// step       : 1-based optimizer step, used for bias correction.
// bias_correction : when 1, fold the standard Adam bias correction into
//              step_size; when 0, use lr directly.
// decay      : decoupled weight-decay coefficient (0 disables decay).
void fused_adam_cuda(
        at::Tensor& p,
        at::Tensor& m,
        at::Tensor& v,
        at::Tensor& g,
        float lr,
        float beta1,
        float beta2,
        float eps,
        float grad_scale,
        int step,
        int bias_correction,
        float decay)
{
    // Get tensor size.
    size_t tsize = p.numel();
    // Determine #threads and #blocks (grid-stride loop in the kernel makes
    // the exact geometry a performance choice, not a correctness one).
    const int threadsPerBlock = 512;
    const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);

    // Fold the Adam bias correction into a single scalar step size, computed
    // in double to avoid float cancellation for large `step`.
    float step_size = lr;
    if (bias_correction == 1) {
        const double bias_correction1 = 1.0 - std::pow(static_cast<double>(beta1), step);
        const double bias_correction2 = 1.0 - std::pow(static_cast<double>(beta2), step);
        step_size = static_cast<float>(lr * std::sqrt(bias_correction2) / bias_correction1);
    }
    // Decoupled weight decay multiplier applied to p inside the kernel.
    float decay_size = 1.0f;
    if (decay != 0.0f) {
        decay_size = 1.0f - step_size * decay;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.scalar_type() == at::ScalarType::Half ||
        g.scalar_type() == at::ScalarType::BFloat16) {
        // Mixed precision: params/grads are half or bf16, moments accumulate
        // in the corresponding float accumulation type.
        AT_ASSERTM(p.scalar_type() == g.scalar_type(),
                   "expected parameter to be the same type as grad");
        using namespace at;  // prevents "toString is undefined" errors
        if (tsize < static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
            // 32-bit indexing is cheaper when the tensor fits.
            DISPATCH_FLOAT_AND_HALF_AND_BF16(g.scalar_type(), 0, "adam_cuda_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                adam_cuda_kernel<accscalar_t, scalar_t_0, int32_t>
                    <<<blocks, threadsPerBlock, 0, stream>>>(
                        p.data_ptr<scalar_t_0>(),
                        m.data_ptr<accscalar_t>(),
                        v.data_ptr<accscalar_t>(),
                        g.data_ptr<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        static_cast<int32_t>(tsize),
                        decay_size);
            );
        } else {
            // Fall back to 64-bit indexing for very large tensors.
            DISPATCH_FLOAT_AND_HALF_AND_BF16(g.scalar_type(), 0, "adam_cuda_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                adam_cuda_kernel<accscalar_t, scalar_t_0, size_t>
                    <<<blocks, threadsPerBlock, 0, stream>>>(
                        p.data_ptr<scalar_t_0>(),
                        m.data_ptr<accscalar_t>(),
                        v.data_ptr<accscalar_t>(),
                        g.data_ptr<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        decay_size);
            );
        }
    } else {
        // Full precision: params, grads and moments share one dtype.
        using namespace at;
        if (tsize < static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
            DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
                adam_cuda_kernel<scalar_t_0, scalar_t_0, int32_t>
                    <<<blocks, threadsPerBlock, 0, stream>>>(
                        p.data_ptr<scalar_t_0>(),
                        m.data_ptr<scalar_t_0>(),
                        v.data_ptr<scalar_t_0>(),
                        g.data_ptr<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        static_cast<int32_t>(tsize),
                        decay_size);
            );
        } else {
            DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
                adam_cuda_kernel<scalar_t_0, scalar_t_0, size_t>
                    <<<blocks, threadsPerBlock, 0, stream>>>(
                        p.data_ptr<scalar_t_0>(),
                        m.data_ptr<scalar_t_0>(),
                        v.data_ptr<scalar_t_0>(),
                        g.data_ptr<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        decay_size);
            );
        }
    }
    // Catch launch-configuration errors immediately; execution errors surface
    // at the next synchronizing call on this stream.
    AT_CUDA_CHECK(cudaGetLastError());
}