#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>

#include "type_shim.h"

// Selects where epsilon is applied in the Adam denominator term.
typedef enum{
    ADAM_MODE_0   =0, // eps under square root
    ADAM_MODE_1   =1  // eps outside square root
} adamMode_t;

// Elementwise Adam update over a flattened tensor of tsize elements.
//
// Launch layout: assumes (up to) 2D grid of (up to) 2D blocks; the indices
// are flattened below and the body uses a grid-stride loop, so any launch
// configuration covering >= 1 thread is correct.
//
// T      - compute/state dtype (params and moments)
// GRAD_T - gradient dtype (may be half for mixed precision)
// p_copy - optional low-precision mirror of p; pass NULL to skip the write.
// g is divided by grad_scale to undo loss scaling before the update.
// The host caller guarantees tsize fits in 32-bit indexing.
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
        T* __restrict__ p,
        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T * __restrict__ g,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay)
{
        //Assuming 2D grids and 2D blocks
        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
        const int threadsPerBlock = blockDim.x * blockDim.y;
        const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
        const int i = (blockId * threadsPerBlock + threadIdInBlock);
        const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;

        for (int j = i; j < tsize; j+=totThreads) {
                T scaled_grad = g[j]/grad_scale;
                // Cache the updated moments in registers: each would otherwise
                // be re-read from global memory in the denom/update expressions.
                T mj = b1*m[j] + (1-b1)*scaled_grad;
                T vj = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
                m[j] = mj;
                v[j] = vj;
                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(vj + eps); // eps under the square root
                else // Mode 1
                    denom = sqrtf(vj) + eps; // eps outside the square root
                // Weight decay is added directly to the update (not to the
                // gradient before the moment estimates).
                float update = (mj/denom) + (decay*p[j]);
                p[j] = p[j] - (step_size*update);
                if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
        }
}

// Host launcher for the fused Adam update.
//
// p       - parameters (fp32); updated in place
// p_copy  - optional low-precision mirror of p (empty tensor to skip)
// m, v    - first/second moment state; updated in place
// g       - gradients (fp32 or fp16); fp16 gradients require fp32 p/m/v
// step    - 1-based step count used for bias correction
// mode    - adamMode_t selecting eps placement (see kernel)
// bias_correction - when 1, folds the standard Adam bias correction into
//                   the step size; otherwise step_size == lr
// decay   - weight decay coefficient applied inside the kernel
void fused_adam_cuda(
        at::Tensor & p,
        at::Tensor & p_copy,
        at::Tensor & m,
        at::Tensor & v,
        at::Tensor & g,
        float lr,
        float beta1,
        float beta2,
        float eps,
        float grad_scale,
        int step,
        int mode,
        int bias_correction,
        float decay)
{
//        using namespace at;

        //Get tensor size
        int tsize = p.numel();
        //Determine #threads and #blocks
        const int threadsPerBlock = 512;
        const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
        // The kernel walks every buffer with one flat index bounded by
        // p.numel(); mismatched state sizes would read/write out of bounds.
        AT_ASSERTM(m.numel() == tsize, "exp_avg must have the same number of elements as p");
        AT_ASSERTM(v.numel() == tsize, "exp_avg_sq must have the same number of elements as p");
        AT_ASSERTM(g.numel() == tsize, "gradient must have the same number of elements as p");
        //Constants
        // Fold the Adam bias correction into a single scalar step size so the
        // kernel does one multiply per element.
        float step_size = 0;
        if (bias_correction == 1) {
            const float bias_correction1 = 1 - std::pow(beta1, step);
            const float bias_correction2 = 1 - std::pow(beta2, step);
            step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
        }
        else {
            step_size = lr;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();

        if (g.type().scalarType() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
            AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
            using namespace at; // prevents "toString is undefined" errors
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
                // accscalar_t is the fp32 accumulation type matching p/m/v;
                // scalar_t (half) is the gradient / p_copy type.
                using accscalar_t = at::acc_type<scalar_t, true>;
                adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t) mode,
                        decay);
            }));
      } else {
            using namespace at;
            AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
                adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
                        p.data<scalar_t>(),
                        NULL, //don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t) mode,
                        decay);
            }));
      }
      // Surface launch-configuration errors from the async kernel launch.
      THCudaCheck(cudaGetLastError());

}