fused_lamb_cuda_kernel.cu 15.1 KB
Newer Older
Samyam Rajbhandari's avatar
Samyam Rajbhandari committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>

#include <iostream>

//#include <helper_functions.h>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include <stdio.h>

namespace cg = cooperative_groups;

// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type.
//
// A templated kernel cannot simply declare `extern __shared__ T buf[];`,
// because every instantiation would redeclare the same symbol with a different
// element type. Instead SharedMemory<T> is specialized per element type, and
// each specialization declares its own extern __shared__ array name.
// Converting a SharedMemory<T> temporary to T* (e.g. `T *s = SharedMemory<T>();`)
// yields the start of the block's dynamic shared memory. Its size is the third
// <<<...>>> launch argument.
//
// Only float and double are specialized here.
namespace {
    // This is the un-specialized struct.  Note that we prevent instantiation of this
    // struct by putting an undefined symbol in the function body so it won't compile.
    // error() is declared but never defined, so using SharedMemory with any
    // element type other than the specializations below fails to build.
    template <typename T>
    struct SharedMemory
    {
        // Ensure that we won't compile any un-specialized types
        __device__ inline operator T *()
        {
            extern __device__ void error(void);
            error();
            return NULL;
        }
    };

    // float specialization: returns the dynamic shared memory viewed as float[].
    template <>
    struct SharedMemory <float>
    {
        __device__ inline operator float *()
        {
            extern __shared__ float s_float[];
            return s_float;
        }
    };

    // double specialization: returns the dynamic shared memory viewed as double[].
    template <>
    struct SharedMemory <double>
    {
        __device__ inline operator double *()
        {
            extern __shared__ double s_double[];
            return s_double;
        }
    };
    } // anonymous namespace

#include "type_shim.h"

// Selects where epsilon is added in the Adam denominator.
enum adamMode_t {
    ADAM_MODE_0 = 0,  // eps under the square root:   sqrt(v + eps)
    ADAM_MODE_1 = 1   // eps outside the square root: sqrt(v) + eps
};


// Sums two arrays across one thread block in a single pass.
//
// s_a and s_b are in shared memory. Each must hold at least blockSize entries,
// and entry tid must already contain thread tid's value.
// g_a and g_b are in global memory. Thread 0 stores the block's two totals at
// index gridDim.x * blockIdx.y + blockIdx.x. This is the same block
// linearization the LAMB kernels use, and equals blockIdx.x for 1D grids.
//
// Preconditions:
//   * Every thread of the block calls this; it contains block-wide barriers.
//   * blockDim.x * blockDim.y == blockSize.
//   * blockSize is a power of two in [32, 1024] (checked at compile time).
// The contents of s_a / s_b are overwritten.
template <typename T, int blockSize>
__device__ void
reduce_block_in_shared_memory(T *s_a, T *s_b, T* g_a, T* g_b)
{
    static_assert(blockSize >= 32 && blockSize <= 1024 &&
                  (blockSize & (blockSize - 1)) == 0,
                  "blockSize must be a power of two in [32, 1024]");

    // Handle to thread block group
    cg::thread_block cta = cg::this_thread_block();

    // perform block reduction in shared memory,
    unsigned int tid = cta.thread_rank();

    T a_sum = s_a[tid];
    T b_sum = s_b[tid];

    cg::sync(cta);

    // Tree reduction in shared memory. blockSize is a compile-time constant,
    // so every branch on it is uniform across the block (barriers inside it
    // are safe) and inactive steps are compiled away.
    if (blockSize >= 1024)
    {
        if (tid < 512)
        {
            s_a[tid] = a_sum = a_sum + s_a[tid + 512];
            s_b[tid] = b_sum = b_sum + s_b[tid + 512];
        }
        cg::sync(cta);
    }

    if ((blockSize >= 512) && (tid < 256))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 256];
        s_b[tid] = b_sum = b_sum + s_b[tid + 256];
    }

    cg::sync(cta);

    if ((blockSize >= 256) && (tid < 128))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 128];
        s_b[tid] = b_sum = b_sum + s_b[tid + 128];
    }

    cg::sync(cta);

    if ((blockSize >= 128) && (tid < 64))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 64];
        s_b[tid] = b_sum = b_sum + s_b[tid + 64];
    }

    cg::sync(cta);

#if (__CUDA_ARCH__ >= 300 )
    // sm_30+: finish the last steps with warp shuffles instead of shared memory.
    // Every thread forms the 32-wide tile outside the divergent branch below.
    // The shuffles therefore always run over all of warp 0, rather than over
    // whichever lanes happen to be converged at that instant (which is what
    // cg::coalesced_threads() would report).
    cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);

    if ( tid < 32 )
    {
        // Fetch final intermediate sum from 2nd warp (written by tid 32..63
        // in the step above, before the last barrier).
        if (blockSize >=  64)
        {
            a_sum = a_sum + s_a[tid + 32];
            b_sum = b_sum + s_b[tid + 32];
        }

        // Reduce final warp using shuffle; lane 0 ends up with the total.
        for (int offset = static_cast<int>(tile32.size()) / 2; offset > 0; offset /= 2)
        {
            a_sum += tile32.shfl_down(a_sum, offset);
            b_sum += tile32.shfl_down(b_sum, offset);
        }
    }
#else
    if ((blockSize >= 64) && (tid < 32))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 32];
        s_b[tid] = b_sum = b_sum + s_b[tid + 32];
    }

    cg::sync(cta);

    if ((blockSize >= 32) && (tid < 16))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 16];
        s_b[tid] = b_sum = b_sum + s_b[tid + 16];
    }

    cg::sync(cta);

    if ((blockSize >= 16) && (tid < 8))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 8];
        s_b[tid] = b_sum = b_sum + s_b[tid + 8];
    }

    cg::sync(cta);

    if ((blockSize >= 8) && (tid < 4))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 4];
        s_b[tid] = b_sum = b_sum + s_b[tid + 4];
    }

    cg::sync(cta);

    if ((blockSize >= 4) && (tid < 2))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 2];
        s_b[tid] = b_sum = b_sum + s_b[tid + 2];
    }

    cg::sync(cta);

    if ((blockSize >= 2) && (tid < 1))
    {
        s_a[tid] = a_sum = a_sum + s_a[tid + 1];
        s_b[tid] = b_sum = b_sum + s_b[tid + 1];
    }

    cg::sync(cta);

#endif

    // write result for this block to global mem, using the same 2D block
    // linearization as the kernels (identical to blockIdx.x for 1D grids)
    if (tid == 0){
        const unsigned int blockId = gridDim.x * blockIdx.y + blockIdx.x;
        g_a[blockId] = a_sum;
        g_b[blockId] = b_sum;
    }
}

// Block-wide sum of two per-thread register values.
//
// Each thread's `a` and `b` are staged in dynamic shared memory, one slot per
// thread: `a` goes in the first block-sized half and `b` in the second. They
// are then reduced by reduce_block_in_shared_memory, whose thread 0 writes the
// two block totals to g_a / g_b.
// Requires 2 * (threads per block) * sizeof(T) bytes of dynamic shared memory.
// Every thread in the block must call it, because the reduction contains
// block-wide barriers.
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b){

    cg::thread_block block = cg::this_thread_block();
    const int rank = block.thread_rank();

    T *smem_base = SharedMemory<T>();
    T *smem_a = smem_base;                  // slots [0, block.size())
    T *smem_b = smem_base + block.size();   // slots [block.size(), 2 * block.size())

    smem_a[rank] = a;
    smem_b[rank] = b;

    reduce_block_in_shared_memory<T, blockSize>(smem_a, smem_b, g_a, g_b);
}


// LAMB step, phase 1 of 3.
//
// For every element j handled by this thread (strided by the total thread
// count), the kernel does three things:
//   * Updates the Adam moments in place:
//       m[j] = b1 * m[j] + (1 - b1) * s
//       v[j] = b2 * v[j] + (1 - b2) * s * s,   with s = g[j] / grad_scale
//   * Forms the update  u = m[j] / denom + decay * p[j], where
//       denom = sqrt(v[j] + eps)   (ADAM_MODE_0)
//       denom = sqrt(v[j]) + eps   (ADAM_MODE_1)
//   * Accumulates sum(p[j]^2) and sum(u^2).
// The per-block partial sums are written to w_l2_i and u_l2_i, one entry per
// block. p is only read here; p_copy and step_size are unused. They appear
// to be kept so the signature matches phase 3.
//
// Launch requirements: blockDim.x * blockDim.y == blockSize, and
// 2 * blockSize * sizeof(T) bytes of dynamic shared memory.
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
        T* __restrict__ p,
        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T * __restrict__ g,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay,
        T* __restrict__ w_l2_i,
        T* __restrict__ u_l2_i)
{
        //Assuming 2D grids and 2D blocks
        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
        const int threadsPerBlock = blockDim.x * blockDim.y;
        const int threadIdInBlock = cg::this_thread_block().thread_rank();
        // size_t indexing: with an int index, `j += totThreads` can overflow
        // when tsize is close to INT_MAX (the host only guarantees that the
        // tensor is 32-bit indexable).
        const size_t i = (size_t)blockId * threadsPerBlock + threadIdInBlock;
        const size_t totThreads = (size_t)gridDim.x * gridDim.y * threadsPerBlock;

        T reg_w = 0;
        T reg_u = 0;

        for (size_t j = i; j < tsize; j += totThreads) {
                const T scaled_grad = g[j] / grad_scale;
                const T pj = p[j];
                const T mj = b1 * m[j] + (1 - b1) * scaled_grad;
                const T vj = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
                m[j] = mj;
                v[j] = vj;

                // Denominator kept in T precision. sqrt() resolves to the float
                // overload for T=float (same result as sqrtf) and no longer
                // narrows to float for T=double.
                T denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrt(vj + eps);
                else // Mode 1
                    denom = sqrt(vj) + eps;
                const T update = (mj / denom) + (decay * pj);

                reg_u += update * update;
                reg_w += pj * pj;
        }

        reduce_two_vectors_in_register<T,blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}

// LAMB step, phase 2 of 3: reduces the per-block partial sums from phase 1.
//
// Launch with a single block of blockSize threads and
// 2 * blockSize * sizeof(T) bytes of dynamic shared memory.
// tsize is the number of valid partial sums in g_a / g_b, i.e. the phase-1
// block count. Thread 0 writes the two totals back to g_a[0] / g_b[0].
// GRAD_T is unused here.
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(
    const size_t tsize,
    T* __restrict__ g_a,
    T* __restrict__ g_b)
{

    T *s_a = SharedMemory<T>() ;
    T *s_b = SharedMemory<T>() + cg::this_thread_block().size();

    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int threadsPerBlock = cg::this_thread_block().size();

    // Each thread folds a strided slice of the partials into one value.
    // Only indices < tsize are read, so g_a / g_b need not hold blockSize
    // entries, and more than blockSize partials are handled as well.
    T a = 0;
    T b = 0;
    for (size_t k = threadIdInBlock; k < tsize; k += threadsPerBlock) {
        a += g_a[k];
        b += g_b[k];
    }
    s_a[threadIdInBlock] = a;
    s_b[threadIdInBlock] = b;

    // All loads from g_a / g_b above complete before the barriers inside the
    // reduction, so thread 0's final write to g_a[0] / g_b[0] cannot race them.
    reduce_block_in_shared_memory<T,blockSize>(s_a, s_b, g_a, g_b);
}


    // LAMB step, phase 3 of 3: applies the trust-ratio-scaled update.
    //
    // w_l2_i[0] and u_l2_i[0] must hold sum(p^2) and sum(update^2), as
    // produced by phases 1 and 2. The LAMB coefficient is
    // ||p|| / ||update||, clamped to at most max_coeff and then to at least
    // min_coeff. It is 1 if either norm is zero. Each element is then updated as
    //     p[j] -= step_size * lamb_coeff * (m[j] / denom + decay * p[j])
    // where denom = sqrt(v[j] + eps) (ADAM_MODE_0) or sqrt(v[j]) + eps
    // (ADAM_MODE_1). m and v are only read here; phase 1 already updated them.
    // Block 0 / thread 0 stores the coefficient in lamb_coeff_val[0]. When
    // p_copy is non-NULL, the new parameters are also written there as GRAD_T.
    // g, b1, b2 and grad_scale are unused. No shared memory is used.
    template <typename T, typename GRAD_T>
    __global__ void lamb_cuda_kernel_part3(
        T* __restrict__ p,
        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T * __restrict__ g,
        const float b1,
        const float b2,
        const float max_coeff,
        const float min_coeff,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay,
        T* __restrict__ w_l2_i,
        T* __restrict__ u_l2_i,
        T* __restrict__ lamb_coeff_val)
{

        //Assuming 2D grids and 2D blocks
        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
        const int threadsPerBlock = blockDim.x * blockDim.y;
        const int threadIdInBlock = cg::this_thread_block().thread_rank();
        // size_t indexing so `j += totThreads` cannot overflow near INT_MAX.
        const size_t i = (size_t)blockId * threadsPerBlock + threadIdInBlock;
        const size_t totThreads = (size_t)gridDim.x * gridDim.y * threadsPerBlock;

        // Norms and coefficient in T precision. sqrt() is the float overload
        // for T=float, so that path matches sqrtf exactly.
        const T reg_w = sqrt(w_l2_i[0]);
        const T reg_u = sqrt(u_l2_i[0]);

        T lamb_coeff = 1.0f;

        if (reg_w != 0 && reg_u != 0){
            lamb_coeff = reg_w / reg_u;
            if (lamb_coeff > max_coeff){
                lamb_coeff = max_coeff;
            }
            if (lamb_coeff < min_coeff){
                lamb_coeff = min_coeff;
            }
        }

        if (blockId == 0 && threadIdInBlock == 0)
        {
            lamb_coeff_val[0] = lamb_coeff;
        }

        for (size_t j = i; j < tsize; j += totThreads) {
            // Read p[j] in T precision. Narrowing through float here would
            // truncate double parameters to single precision on every step.
            T pj = p[j];
            const T mj = m[j];
            const T vj = v[j];
            T denom;
            if (mode == ADAM_MODE_0)
                denom = sqrt(vj + eps);
            else // Mode 1
                denom = sqrt(vj) + eps;
            const T update = (mj / denom) + (decay * pj);

            pj = pj - (step_size * lamb_coeff * update);
            p[j] = pj;
            if (p_copy != NULL) p_copy[j] = (GRAD_T) pj;
        }
}

// Launches the three LAMB phases for one parameter tensor on `stream`:
// phase 1 writes per-block partial sums into w_l2_i / u_l2_i, phase 2 reduces
// them into w_l2_i[0] / u_l2_i[0], and phase 3 applies the update.
//
// T is the element type of p, m, v and the reduction/coefficient buffers.
// GRAD_T is the gradient (and optional p_copy) type.
// w_l2_i / u_l2_i need room for one partial per phase-1 block (at most
// maxBlocks entries). NOTE(review): confirm against the caller's allocation.
template <typename T, typename GRAD_T>
static void launch_fused_lamb_kernels(
        T* p,
        GRAD_T* p_copy,
        T* m,
        T* v,
        const GRAD_T* g,
        float beta1,
        float beta2,
        float max_coeff,
        float min_coeff,
        float eps,
        float grad_scale,
        float step_size,
        int tsize,
        adamMode_t mode,
        float decay,
        T* w_l2_i,
        T* u_l2_i,
        T* lamb_coeff,
        cudaStream_t stream)
{
        // Compile-time constant: it is also the blockSize template argument of
        // the reduction kernels.
        constexpr int threadsPerBlock = 512;
        // Phase 2 reduces all phase-1 partials with a single block, so cap the
        // grid. The kernels stride over the tensor, so any grid size is correct.
        constexpr int maxBlocks = 512;

        int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
        if (num_blocks > maxBlocks) num_blocks = maxBlocks;
        // A zero-block launch is an invalid configuration. With one block, an
        // empty tensor yields zero norms, lamb_coeff = 1 and no parameter writes.
        if (num_blocks < 1) num_blocks = 1;

        // The reductions stage two arrays of threadsPerBlock T values.
        const size_t smemsize = 2 * threadsPerBlock * sizeof(T);

        lamb_cuda_kernel_part1<T, GRAD_T, threadsPerBlock><<<num_blocks, threadsPerBlock, smemsize, stream>>>(
                p, p_copy, m, v, g,
                beta1, beta2, eps, grad_scale, step_size,
                tsize, mode, decay,
                w_l2_i, u_l2_i);
        THCudaCheck(cudaGetLastError());

        lamb_cuda_kernel_part2<T, GRAD_T, threadsPerBlock><<<1, threadsPerBlock, smemsize, stream>>>(
                num_blocks, w_l2_i, u_l2_i);
        THCudaCheck(cudaGetLastError());

        // Phase 3 uses no shared memory.
        lamb_cuda_kernel_part3<T, GRAD_T><<<num_blocks, threadsPerBlock, 0, stream>>>(
                p, p_copy, m, v, g,
                beta1, beta2, max_coeff, min_coeff, eps, grad_scale, step_size,
                tsize, mode, decay,
                w_l2_i, u_l2_i, lamb_coeff);
        THCudaCheck(cudaGetLastError());
}

// Host entry point: one fused LAMB step for a single parameter tensor.
//
// All tensors are assumed to be CUDA tensors on the current device. Two type
// combinations are accepted:
//   * g is Half: p must be Float; m, v, w_l2_i, u_l2_i and lamb_coeff are read
//     as the accumulate type (float). If p_copy is non-empty, it receives the
//     updated parameters as Half.
//   * g is Float or Double: every tensor is read with g's type, and p_copy
//     is ignored.
// mode selects where eps enters the Adam denominator (see adamMode_t).
// When bias_correction == 1, the step size is
// lr * sqrt(1 - beta2^step) / (1 - beta1^step); otherwise it is lr.
// The kernels run asynchronously on the current CUDA stream. Phase 3 stores
// the clamped LAMB coefficient in lamb_coeff[0].
void fused_lamb_cuda(
        at::Tensor & p,
        at::Tensor & p_copy,
        at::Tensor & m,
        at::Tensor & v,
        at::Tensor & g,
        float lr,
        float beta1,
        float beta2,
        float max_coeff,
        float min_coeff,
        float eps,
        float grad_scale,
        int step,
        int mode,
        int bias_correction,
        float decay,
        at::Tensor & w_l2_i,
        at::Tensor & u_l2_i,
        at::Tensor & lamb_coeff)
{
        // Check before narrowing numel() to int.
        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
        const int tsize = static_cast<int>(p.numel());

        //Constants
        float step_size = lr;
        if (bias_correction == 1) {
            const float bias_correction1 = 1 - std::pow(beta1, step);
            const float bias_correction2 = 1 - std::pow(beta2, step);
            step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        const adamMode_t adam_mode = (adamMode_t) mode;

        if (g.scalar_type() == at::ScalarType::Half) {
            //all other values should be fp32 for half gradients
            AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
            //dispatch is done on the gradient type
            using namespace at; // prevents "toString is undefined" errors
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.scalar_type(), "lamb_cuda_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;
                launch_fused_lamb_kernels<accscalar_t, scalar_t>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : nullptr,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        adam_mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>(),
                        lamb_coeff.data<accscalar_t>(),
                        stream);
            }));
        } else {
            using namespace at;
            AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "lamb_cuda_kernel", ([&] {
                launch_fused_lamb_kernels<scalar_t, scalar_t>(
                        p.data<scalar_t>(),
                        nullptr, //don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        adam_mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>(),
                        lamb_coeff.data<scalar_t>(),
                        stream);
            }));
        }
}

//template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a, float* g_b, cg::grid_group &cgg);