// fused_adam_cuda_kernel.cu
// Fused Adam optimizer step: a single-tensor kernel and a multi-tensor functor, plus their host launchers.
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
// NOTE(review): presumably declares multi_tensor_apply and TensorListMetadata used further down - confirm.
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4
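
// NOTE(review): presumably supplies the DISPATCH_FLOAT_AND_HALF / DISPATCH_DOUBLE_AND_FLOAT macros - confirm.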
#include "type_shim.h"

// ---- Adam denominator mode ----
// Selects where epsilon is added when the Adam denominator is formed.
typedef enum {
    // eps under square root: denom = sqrt(v + eps)
    ADAM_MODE_0 = 0,
    // eps outside square root: denom = sqrt(v) + eps
    ADAM_MODE_1 = 1
} adamMode_t;

/**
 * Fused Adam step applied element-wise to a flat buffer of `tsize` elements.
 *
 * For every element j the kernel computes
 *     scaled_grad = g[j] / grad_scale
 *     m[j]        = b1 * m[j] + (1 - b1) * scaled_grad
 *     v[j]        = b2 * v[j] + (1 - b2) * scaled_grad^2
 *     denom       = sqrt(v[j] + eps)      if mode == ADAM_MODE_0
 *                   sqrt(v[j]) + eps      otherwise
 *     p[j]        = p[j] - step_size * (m[j] / denom + decay * p[j])
 * and, when `p_copy` is not NULL, stores the updated p[j] cast to GRAD_T
 * into p_copy[j].
 *
 * Launch layout: 1D or 2D grid of 1D or 2D blocks. Threads are flattened to a
 * linear id and walk the buffer with a grid-stride loop, so any grid size
 * covers all `tsize` elements.
 *
 * @param p          parameters, updated in place
 * @param p_copy     optional GRAD_T copy of the updated parameters (NULL to skip)
 * @param m          first-moment buffer, updated in place
 * @param v          second-moment buffer, updated in place
 * @param g          gradients (read only)
 * @param b1, b2     exponential decay rates for m and v
 * @param eps        epsilon used in the denominator (placement chosen by `mode`)
 * @param grad_scale divisor applied to every gradient before use
 * @param step_size  multiplier applied to the update
 * @param tsize      number of elements to process
 * @param mode       ADAM_MODE_0 or ADAM_MODE_1
 * @param decay      coefficient of the p[j] term added to the update
 *
 * Note: `denom` and `update` are held in float regardless of T.
 */
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
        T* __restrict__ p,
        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T * __restrict__ g,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay)
{
        //Assuming 2D grids and 2D blocks
        // All index math is done in size_t: the loop bound `tsize` is size_t, and
        // a signed 32-bit index could overflow on the final `j += totThreads`.
        const size_t blockId = (size_t)gridDim.x * blockIdx.y + blockIdx.x;
        const size_t threadsPerBlock = (size_t)blockDim.x * blockDim.y;
        const size_t threadIdInBlock = (size_t)threadIdx.y * blockDim.x + threadIdx.x;
        const size_t i = blockId * threadsPerBlock + threadIdInBlock;
        const size_t totThreads = (size_t)gridDim.x * gridDim.y * threadsPerBlock;

        for (size_t j = i; j < tsize; j += totThreads) {
                T scaled_grad = g[j]/grad_scale;

                // Keep the new moments in registers; they are reused below.
                const T m_new = b1*m[j] + (1-b1)*scaled_grad;
                const T v_new = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
                m[j] = m_new;
                v[j] = v_new;

                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(v_new + eps);
                else // Mode 1
                    denom = sqrtf(v_new) + eps;

                const T p_old = p[j];
                float update = (m_new/denom) + (decay*p_old);
                const T p_new = p_old - (step_size*update);
                p[j] = p_new;

                if (p_copy != NULL) p_copy[j] = (GRAD_T) p_new;
        }
}

// ---- Multi-tensor path: functor handed to multi_tensor_apply ----
/**
 * Device functor performing the fused Adam step on one chunk of one tensor.
 *
 * The block looks up which tensor and which chunk it owns through
 * tl.block_to_tensor[blockIdx.x] and tl.block_to_chunk[blockIdx.x], then
 * processes at most `chunk_size` elements starting at chunk_idx * chunk_size.
 *
 * Address table layout read from `tl.addresses`:
 *   [0] p (T), [1] m (T), [2] v (T), [3] g (GRAD_T),
 *   [4] p_copy (GRAD_T), only read when DEPTH == 5.
 *
 * Per element the math is:
 *     scaled_grad = g / grad_scale
 *     m = b1 * m + (1 - b1) * scaled_grad
 *     v = b2 * v + (1 - b2) * scaled_grad^2
 *     denom = sqrt(v + eps) (ADAM_MODE_0) or sqrt(v) + eps (otherwise)
 *     p = p - step_size * (m / denom + decay * p)
 *
 * Only threadIdx.x / blockDim.x / blockIdx.x are used, i.e. a 1D launch.
 * `noop_gmem` is accepted but not consulted by this functor.
 */
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
    __device__ __forceinline__ void operator()(
        int chunk_size,
        volatile int* noop_gmem,
        TensorListMetadata<DEPTH>& tl,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        adamMode_t mode,
        const float decay)
    {
        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        // Element offset of this block's chunk inside the tensor.
        const int chunk_offset = chunk_idx*chunk_size;

        T* p = (T *)tl.addresses[0][tensor_loc];
        p += chunk_offset;
        T* m = (T *)tl.addresses[1][tensor_loc];
        m += chunk_offset;
        T* v = (T *)tl.addresses[2][tensor_loc];
        v += chunk_offset;
        GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
        g += chunk_offset;
        GRAD_T* p_copy = NULL;
        if (DEPTH == 5) {
            p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
            p_copy += chunk_offset;
        }

        // Elements remaining from the start of this chunk to the end of the tensor.
        n -= chunk_offset;

        // Number of valid elements in this chunk (the last chunk may be short).
        const int limit = (n < chunk_size) ? n : chunk_size;

        T incoming_p[ILP];
        T incoming_m[ILP];
        T incoming_v[ILP];
        T incoming_g[ILP];

        for(int i_start = 0;
            i_start < limit;
            i_start += blockDim.x*ILP) {

            // Stage ILP elements per thread into registers; out-of-range slots stay 0.
            #pragma unroll
            for(int ii = 0; ii < ILP; ii++) {
                incoming_p[ii] = 0;
                incoming_m[ii] = 0;
                incoming_v[ii] = 0;
                incoming_g[ii] = 0;

                int i = i_start + threadIdx.x + ii*blockDim.x;
                if (i < limit) {
                    incoming_p[ii] = p[i];
                    incoming_m[ii] = m[i];
                    incoming_v[ii] = v[i];
                    incoming_g[ii] = static_cast<T>(g[i]);
                }
            }

            // note for clarification to future michael:
            // From a pure memory dependency perspective, there's likely no point unrolling
            // the write loop, since writes just fire off once their LDGs arrive.
            // Put another way, the STGs are dependent on the LDGs, but not on each other.
            // There is still compute ILP benefit from unrolling the loop though.
            #pragma unroll
            for(int ii = 0; ii < ILP; ii++) {
                int j = i_start + threadIdx.x + ii*blockDim.x;

                if(j < limit) {
                    T scaled_grad = incoming_g[ii]/grad_scale;

                    // Keep the freshly computed values in registers instead of
                    // re-reading them from global memory right after the store.
                    const T m_new = b1*incoming_m[ii] + (1-b1)*scaled_grad;
                    const T v_new = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
                    m[j] = m_new;
                    v[j] = v_new;

                    float denom;
                    if (mode == ADAM_MODE_0)
                        denom = sqrtf(v_new + eps);
                    else // Mode 1
                        denom = sqrtf(v_new) + eps;

                    float update = (m_new/denom) + (decay*incoming_p[ii]);
                    const T p_new = incoming_p[ii] - (step_size*update);
                    p[j] = p_new;
                    if (DEPTH == 5)  p_copy[j] = (GRAD_T) p_new;
                }
            }
        }
    }
};

// ---- Host launcher: single-tensor fused Adam step ----
/**
 * Host launcher for adam_cuda_kernel on a single parameter tensor.
 *
 * Computes the effective step size, picks the kernel instantiation from the
 * gradient dtype and launches on the current CUDA stream:
 *   - half gradients: p must be float; p/m/v use the accumulate type of the
 *     gradient dtype, and p_copy (if it has elements) receives a copy of the
 *     updated parameters in the gradient dtype;
 *   - otherwise: p/m/v/g all share the gradient dtype and p_copy is not written.
 *
 * @param p               parameters, updated in place
 * @param p_copy          optional low-precision copy target (ignored when empty)
 * @param m, v            first/second moment buffers, updated in place
 * @param g               gradients
 * @param lr              learning rate
 * @param beta1, beta2    moment decay rates
 * @param eps             denominator epsilon
 * @param grad_scale      divisor applied to gradients inside the kernel
 * @param step            step count used for bias correction
 * @param mode            cast to adamMode_t and forwarded to the kernel
 * @param bias_correction 1 to fold bias correction into the step size
 * @param decay           coefficient of the parameter term in the update
 *
 * Does nothing when `p` has no elements. The launch status is checked with
 * THCudaCheck(cudaGetLastError()).
 */
void fused_adam_cuda(
        at::Tensor & p,
        at::Tensor & p_copy,
        at::Tensor & m,
        at::Tensor & v,
        at::Tensor & g,
        float lr,
        float beta1,
        float beta2,
        float eps,
        float grad_scale,
        int step,
        int mode,
        int bias_correction,
        float decay)
{
//        using namespace at;

        // Nothing to update; a zero-block grid would be an invalid launch configuration.
        if (p.numel() == 0) return;

        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
        //Get tensor size
        int tsize = p.numel();
        //Determine #threads and #blocks
        const int threadsPerBlock = 512;
        const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
        //Constants
        float step_size = 0;
        if (bias_correction == 1) {
            // With step == 0 bias_correction1 is 0, so the division below is by zero;
            // callers are expected to pass step >= 1 here.
            const float bias_correction1 = 1 - std::pow(beta1, step);
            const float bias_correction2 = 1 - std::pow(beta2, step);
            step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
        }
        else {
            step_size = lr;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();

        if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
            AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
            using namespace at; // prevents "toString is undefined" errors
            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t) mode,
                        decay);
            );
      } else {
            using namespace at;
            DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
                adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                        p.data<scalar_t_0>(),
                        NULL, //don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t_0>(),
                        v.data<scalar_t_0>(),
                        g.data<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t) mode,
                        decay);
            );
      }
      THCudaCheck(cudaGetLastError());

}
// ---- Host launcher: multi-tensor fused Adam step ----

/**
 * Host launcher for the multi-tensor fused Adam step.
 *
 * `tensor_lists` holds 4 or 5 parallel lists in the order p, m, v, g
 * [, p_copy]. The functor depth (4 or 5) follows the number of lists, and the
 * element types follow the dtype of the first gradient tensor:
 *   - half gradients: the first parameter tensor must be float, and the
 *     functor uses the accumulate type of the gradient dtype for p/m/v;
 *   - otherwise: p/m/v/g all use the gradient dtype.
 *
 * @param chunk_size      elements per chunk, forwarded to multi_tensor_apply
 * @param noop_flag       forwarded to multi_tensor_apply
 * @param tensor_lists    p, m, v, g and optionally p_copy
 * @param lr              learning rate
 * @param beta1, beta2    moment decay rates
 * @param eps             denominator epsilon
 * @param grad_scale      divisor applied to gradients inside the functor
 * @param step            step count used for bias correction
 * @param mode            cast to adamMode_t and forwarded to the functor
 * @param bias_correction 1 to fold bias correction into the step size
 * @param decay           coefficient of the parameter term in the update
 *
 * The launch status is checked with THCudaCheck(cudaGetLastError()).
 */
void fused_adam_cuda_mt(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay) {

    //Constants
    float step_size = 0;
    if (bias_correction == 1) {
        // With step == 0 bias_correction1 is 0, so the division below is by zero;
        // callers are expected to pass step >= 1 here.
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
    }
    else {
        step_size = lr;
    }

    size_t tl_sz = tensor_lists.size();
    AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
    // The dtype checks below read element [0] of the p and g lists.
    AT_ASSERTM(!tensor_lists[0].empty() && !tensor_lists[3].empty(), "expected non-empty tensor lists");

    if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
        AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
        if (tl_sz == 5) {
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<5, accscalar_t, scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
            }));
        } else {
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<4, accscalar_t, scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
            }));
        }
    } else {
        if (tl_sz == 5) {
            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<5, scalar_t, scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
            }));
        } else {
            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<4, scalar_t, scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
            }));
        }
    }
    THCudaCheck(cudaGetLastError());
}