// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include <kernels.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
Tim Dettmers's avatar
Tim Dettmers committed
15
16
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
Tim Dettmers's avatar
Tim Dettmers committed
17
#include <mma.h>
Tim Dettmers's avatar
Tim Dettmers committed
18

Tim Dettmers's avatar
Tim Dettmers committed
19
20
21
#include <cooperative_groups/memcpy_async.h>
#include <cuda/pipeline>

Tim Dettmers's avatar
Tim Dettmers committed
22
23
24
25
26
#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096

Tim Dettmers's avatar
Tim Dettmers committed
27
28
using namespace nvcuda;

Tim Dettmers's avatar
Tim Dettmers committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
// Atomically replaces *address with fmaxf(*address, val). The float's bit
// pattern is updated with a 32-bit atomicCAS that is retried until no other
// thread modified the word between the read and the swap.
// Returns the value that was stored before this thread's successful update.
__device__ float atomicMax(float* address, float val) {
  int* const bits = reinterpret_cast<int*>(address);
  int observed = *bits;
  int expected;
  do {
    expected = observed;
    const float candidate = fmaxf(val, __int_as_float(expected));
    observed = atomicCAS(bits, expected, __float_as_int(candidate));
  } while (observed != expected);
  return __int_as_float(observed);
}

// Atomically replaces *address with fminf(*address, val). This uses the same
// atomicCAS retry loop on the float's bit pattern as the float atomicMax.
// Returns the value that was stored before this thread's successful update.
__device__ float atomicMin(float* address, float val) {
  int* const bits = reinterpret_cast<int*>(address);
  int observed = *bits;
  int expected;
  do {
    expected = observed;
    const float candidate = fminf(val, __int_as_float(expected));
    observed = atomicCAS(bits, expected, __float_as_int(candidate));
  } while (observed != expected);
  return __int_as_float(observed);
}

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// Decodes one 4-bit FP4 code and scales it by absmax.
// Bit 0b1000 is the sign. The low three bits select the magnitude:
//   0b000 -> 0, 0b001 -> 0.0625 (subnormals)
//   0b110 -> 2, 0b111 -> 3, 0b100 -> 4, 0b101 -> 6, 0b010 -> 8, 0b011 -> 12
__device__ float dDequantizeFP4(unsigned char val, float absmax)
{
  const bool negative = (val & 0b1000) != 0;
  const float sign = negative ? -1.0f : 1.0f;
  const bool has_exponent = (val & 0b0110) != 0;

  if(!has_exponent)
  {
    // subnormal range: only zero and the smallest magnitude exist
    return (val & 0b0001) ? sign*0.0625f*absmax : 0.0f;
  }

  // normal range: (base + offset) scaled by a 1.0 / 1.5 mantissa
  const float base = (val & 0b0100) ? 2.0f : 8.0f;
  const float offset = (val & 0b0010) ? 0.0f : 2.0f;
  const float exponent = base + offset;
  const float fraction = (val & 0b0001) ? 1.5f : 1.0f;
  return sign*exponent*fraction*absmax;
}

Tim Dettmers's avatar
Tim Dettmers committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Decodes one 4-bit FP4 code to its unscaled value in [-12, 12]. The encoding
// is the same as dDequantizeFP4, but there is no absmax factor.
__device__ float d2DequantizeFP4(unsigned char val)
{
  const float sign = (val & 0b1000) ? -1.0f : 1.0f;

  // Both exponent bits clear: subnormal, i.e. 0 or +/-0.0625.
  if(!(val & 0b0110))
    return (val & 0b0001) ? sign*0.0625f : 0.0f;

  float exponent = (val & 0b0100) ? 2.0f : 8.0f;
  exponent = exponent + ((val & 0b0010) ? 0.0f : 2.0f);
  const float fraction = (val & 0b0001) ? 1.5f : 1.0f;
  return sign*exponent*fraction;
}

Tim Dettmers's avatar
Tim Dettmers committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Decodes one 4-bit FP4 code and scales it by absmax. The magnitudes are the
// FP4 values divided by 12 (the largest FP4 magnitude), so they lie in [0, 1].
// Bit 0b1000 is the sign.
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
{
  const float sign = (val & 0b1000) ? -1.0f : 1.0f;

  float magnitude;
  switch(val & 0b0111)
  {
    case 0b111: magnitude = 0.25000000f;      break; //  3/12
    case 0b110: magnitude = 0.16666667f;      break; //  2/12
    case 0b101: magnitude = 0.50000000f;      break; //  6/12
    case 0b100: magnitude = 0.33333333f;      break; //  4/12
    case 0b011: magnitude = 1.00000000f;      break; // 12/12
    case 0b010: magnitude = 0.66666667f;      break; //  8/12
    case 0b001: magnitude = 5.208333333e-03f; break; // 0.0625/12
    default:    magnitude = 0.00000000f;      break; // 0b000
  }
  return magnitude*absmax*sign;
}

123
124
125
126
127
128
129
130
131
132
133
134
135
136
// Quantizes x (expected to be normalized to [-1.0, 1.0]) to the nearest
// 4-bit FP4 code. Bit 0b1000 holds the sign. The low three bits select a
// magnitude with an exponent bias of 3:
//   0b000 = 0      0b001 = 0.0625   (subnormals)
//   0b110 = 2      0b111 = 3
//   0b100 = 4      0b101 = 6
//   0b010 = 8      0b011 = 12
// The thresholds below are the midpoints between adjacent magnitudes, divided
// by 12 (the largest FP4 magnitude) to match the normalized input range.
//
// Be careful when editing these constants: a misplaced digit is easy to
// introduce and hard to notice.
__device__ unsigned char dQuantizeFP4(float x)
{
  const int sign = (x < 0) ? 0b1000 : 0b0000;
  const float ax = fabsf(x);

  int magnitude;
  if(ax > 0.29166667f)
  {
    if(ax > 0.583333f)
      magnitude = (ax > 0.8333333f) ? 0b0011 : 0b0010;   // 12 or 8
    else
      magnitude = (ax > 0.4166667f) ? 0b0101 : 0b0100;   // 6 or 4
  }
  else
  {
    if(ax > 0.0859375f)
      magnitude = (ax > 0.20833333f) ? 0b0111 : 0b0110;  // 3 or 2
    else
      magnitude = (ax > 0.00260417f) ? 0b0001 : 0b0000;  // 0.0625 or 0
  }
  return magnitude + sign;
}

Tim Dettmers's avatar
Tim Dettmers committed
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
// Maps a 4-bit NF4 code to its normalized value in [-1, 1] and returns it as
// half. Only the low four bits of val are used.
// The values for this table were generated by test_normal_map_tree
// in the file tests/test_functional.py.
__device__ half dhDequantizeNF4(unsigned char val)
{
  switch(val & 0b1111)
  {
    case 0b1111: return 1.0f;
    case 0b1110: return 0.7229568362236023f;
    case 0b1101: return 0.5626170039176941f;
    case 0b1100: return 0.44070982933044434f;
    case 0b1011: return 0.33791524171829224f;
    case 0b1010: return 0.24611230194568634f;
    case 0b1001: return 0.16093020141124725f;
    case 0b1000: return 0.07958029955625534f;
    case 0b0111: return 0.0f;
    case 0b0110: return -0.09105003625154495f;
    case 0b0101: return -0.18477343022823334f;
    case 0b0100: return -0.28444138169288635f;
    case 0b0011: return -0.39491748809814453f;
    case 0b0010: return -0.5250730514526367f;
    case 0b0001: return -0.6961928009986877f;
    default:     return -1.0f; // 0b0000
  }
}

__device__ float dDequantizeNF4(unsigned char val)
{
  // Maps a 4-bit NF4 code (only the low four bits of val are used) to its
  // normalized value in [-1, 1]. Codes are ordered by value: 0b0000 is the
  // most negative entry and 0b1111 the largest.
  // The values for this table were generated by test_normal_map_tree
  // in the file tests/test_functional.py.
  switch(val & 0x0F)
  {
    case 0x1: return -0.6961928009986877f;
    case 0x2: return -0.5250730514526367f;
    case 0x3: return -0.39491748809814453f;
    case 0x4: return -0.28444138169288635f;
    case 0x5: return -0.18477343022823334f;
    case 0x6: return -0.09105003625154495f;
    case 0x7: return 0.0f;
    case 0x8: return 0.07958029955625534f;
    case 0x9: return 0.16093020141124725f;
    case 0xA: return 0.24611230194568634f;
    case 0xB: return 0.33791524171829224f;
    case 0xC: return 0.44070982933044434f;
    case 0xD: return 0.5626170039176941f;
    case 0xE: return 0.7229568362236023f;
    case 0xF: return 1.0f;
    default:  return -1.0f; // 0x0
  }
}

280
__device__ unsigned char dQuantizeNF4(float x)
{
  // Maps x (expected in [-1, 1]) to the nearest of the 16 NF4 code values.
  // Each threshold is the midpoint between two adjacent code values. The
  // values were generated by test_normal_map_tree in tests/test_functional.py.
  // The first comparison decides the top bit of the code; each following
  // comparison decides the next bit.
  if(x > 0.03979014977812767f)
  {
    if(x > 0.3893125355243683f)
    {
      if(x > 0.6427869200706482f)
        return (x > 0.8614784181118011f) ? 0b1111 : 0b1110;
      return (x > 0.5016634166240692f) ? 0b1101 : 0b1100;
    }
    if(x > 0.2035212516784668f)
      return (x > 0.2920137718319893f) ? 0b1011 : 0b1010;
    return (x > 0.1202552504837513f) ? 0b1001 : 0b1000;
  }

  if(x > -0.33967943489551544f)
  {
    if(x > -0.13791173323988914f)
      return (x > -0.045525018125772476f) ? 0b0111 : 0b0110;
    return (x > -0.23460740596055984f) ? 0b0101 : 0b0100;
  }
  if(x > -0.6106329262256622f)
    return (x > -0.4599952697753906f) ? 0b0011 : 0b0010;
  return (x > -0.8480964004993439f) ? 0b0001 : 0b0000;
}

Tim Dettmers's avatar
Tim Dettmers committed
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
    // Maps x to an index into the 256-entry code book smem_code using a
    // 7-step binary search. This assumes the code book is sorted ascending.
    // STOCHASTIC == 0 picks whichever bracketing entry is nearer; a tie at
    // the midpoint keeps the probed entry.
    // STOCHASTIC != 0 rounds to the far neighbour when rand >= (distance to
    // that neighbour) / (neighbour spacing). That is probability-proportional
    // rounding if rand is uniform in [0, 1) — confirm with the caller.
    int idx = 127;
    int hi_idx = 255;
    int lo_idx = 0;
    float lo_val = -1.0f;
    float hi_val = 1.0f;
    float probe = smem_code[idx];

    // Step sizes 64, 32, 16, 8, 4, 2, 1; idx ends somewhere in [0, 254].
    for(int step = 64; step >= 1; step /= 2)
    {
        if(x > probe)
        {
            lo_idx = idx;
            lo_val = probe;
            idx += step;
        }
        else
        {
            hi_idx = idx;
            hi_val = probe;
            idx -= step;
        }
        probe = smem_code[idx];
    }

    // If a bound was never moved, take it from the code book ends instead of +/-1.
    if(hi_idx == 255)
        hi_val = smem_code[255];
    if(lo_idx == 0)
        lo_val = smem_code[0];

    const bool above = x > probe;
    if(STOCHASTIC)
    {
        if(above)
            return (rand >= fabsf(hi_val-x)/(hi_val-probe)) ? hi_idx : idx;
        return (rand >= fabsf(lo_val-x)/(probe-lo_val)) ? lo_idx : idx;
    }

    if(above)
        return (x > (hi_val+probe)*0.5f) ? hi_idx : idx;
    return (x < (lo_val+probe)*0.5f) ? lo_idx : idx;
}

template <int SIGNED>
__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x)
{
    // Nearest-entry search over the 256-entry code book, like dQuantize<0>.
    // The first probe is quadrants[1]. The second probe is quadrants[2] when
    // x is above it, otherwise quadrants[0]. All later probes read smem_code.
    // (Presumably quadrants caches smem_code[127], [191] and [63]; confirm
    // with the caller.)
    // Unlike dQuantize, the outer bounds stay at
    // 1.0 / (SIGNED ? -1.0 : 0.0) if the search never moves them.
    int idx = 127;
    int hi_idx = 255;
    int lo_idx = 0;
    float lo_val = SIGNED ? -1.0f : 0.0f;
    float hi_val = 1.0f;
    float probe = quadrants[1];

    for(int step = 64; step > 0; step >>= 1)
    {
        const bool right = x > probe;
        if(right)
        {
            lo_idx = idx;
            lo_val = probe;
            idx += step;
        }
        else
        {
            hi_idx = idx;
            hi_val = probe;
            idx -= step;
        }
        // Only the probe after the first step comes from `quadrants`.
        probe = (step == 64) ? quadrants[right ? 2 : 0] : smem_code[idx];
    }

    if(x > probe)
      return (x > (hi_val+probe)*0.5f) ? hi_idx : idx;
    return (x < (lo_val+probe)*0.5f) ? lo_idx : idx;
}

template <int SIGNED>
__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper)
{
    // Binary search for the code-book index nearest to x within a region of
    // smem_code selected by QUADRANT.
    // The caller supplies the first probe value (`midpoint`) and the initial
    // bracketing values `lower` / `upper`.
    // Returns the pivot or one of its bracketing indices, whichever is
    // nearest; a tie at a midpoint keeps the pivot.
    // SIGNED is not referenced in this function.
    //
    // NOTE(review): the steps are 16, 8, 4, 2, 1 and the search starts at
    // pivot = QUADRANT*16+15. The final pivot can therefore range from
    // QUADRANT*16-16 to QUADRANT*16+46. That is outside the initial
    // [lower_pivot, upper_pivot] window, and for QUADRANT == 0 it reads
    // smem_code at negative indices (if smem_code points at the start of the
    // table). Confirm the intended region width and first step before use.
    int lower_pivot = QUADRANT*16-1 - 0;
    int pivot = QUADRANT*16-1 + 16;
    int upper_pivot = QUADRANT*16-1 + 31;

    float val = midpoint;

    // step sizes: 16, 8, 4, 2, 1
    for(int i = 16; i > 0; i>>=1)
    {
        if(x > val)
        {
            lower_pivot = pivot;
            lower = val;
            pivot+=i;
        }
        else
        {
            upper_pivot = pivot;
            upper = val;
            pivot-=i;
        }
        val = smem_code[pivot];
    }

    // Pick the nearer of the pivot and the bound on x's side of it.
    if(x > val)
    {
      midpoint = (upper+val)*0.5f;
      if(x > midpoint)
        return upper_pivot;
      else
        return pivot;
    }
    else
    {
      midpoint = (lower+val)*0.5f;
      if(x < midpoint)
        return lower_pivot;
      else
        return pivot;
    }
}

// Scatter-adds src[i] into histogram at the flat bin
// index1[i]*maxidx1 + index2[i], for i in [0, n).
// A grid-stride loop lets any launch configuration cover all n inputs.
// atomicAdd keeps concurrent updates to the same bin correct.
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
{
  const int stride = gridDim.x*blockDim.x;
  for(int k = blockIdx.x*blockDim.x + threadIdx.x; k < n; k += stride)
  {
    const int bin = index2[k] + index1[k]*maxidx1;
    atomicAdd(histogram + bin, src[k]);
  }
}

template<typename T, int BLOCK_SIZE, int NUM_MAX>
__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
{
  // Work in progress. Each thread loads 8 consecutive values, so the block is
  // expected to have BLOCK_SIZE/8 threads. Each warp then extracts 8
  // large-magnitude values from its 256 inputs, in descending order of
  // magnitude. Each lane contributes at most its two largest values. The warp
  // records each value (sign restored) and its in-block index in shared memory.
  // NOTE(review): the final stage that writes `out` / `out_idx` is not
  // implemented, so this kernel has no global side effects yet. NUM_MAX is
  // unused.
  constexpr int THREADS = BLOCK_SIZE/8;
  constexpr int NUM_WARPS = (THREADS + 31)/32;

  typedef cub::WarpReduce<T> WarpReduce;
  // cub::WarpReduce needs a separate TempStorage per warp; one shared
  // instance for the whole block is a data race between warps.
  __shared__ typename WarpReduce::TempStorage temp_storage[NUM_WARPS];
  typedef cub::BlockLoad<T, THREADS, 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
  __shared__ typename LoadT::TempStorage loadt;

  const int warp_idx = threadIdx.x/32;
  const int lane_idx = threadIdx.x % 32;
  const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE);

  // 8 recorded maxima per warp.
  __shared__ int smem_max_indices[8*NUM_WARPS];
  __shared__ float smem_max_values[8*NUM_WARPS];

  T values[8];
  // max1/max2 hold magnitudes (|value|); the signs are kept in sign1/sign2.
  T max1 = -64000.0f;
  T max2 = -64000.0f;
  int max_idx1 = -1;
  int max_idx2 = -1;
  int sign1 = -1;
  int sign2 = -1;

  // 1. load 8 values per thread (out-of-range items read as 0)
  LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f);

  // 2. per-thread top-2 by magnitude
  #pragma unroll 8
  for(int i = 0; i < 8; i++)
  {
    T absval = fabsf(values[i]);
    if(absval > max1)
    {
      // demote the previous maximum so the runner-up is not lost
      max2 = max1;
      sign2 = sign1;
      max_idx2 = max_idx1;

      max1 = absval;
      sign1 = signbit(values[i]);
      max_idx1 = 8*threadIdx.x + i;
    }
    else if(absval > max2)
    {
      max2 = absval;
      sign2 = signbit(values[i]);
      max_idx2 = 8*threadIdx.x + i;
    }
  }

  // 3.-5. Eight times: find the warp-wide maximum, record it, and promote the
  // owning lane's runner-up. Assumes all 32 lanes of each warp are active.
  for(int i = 0; i < 8; i++)
  {
    float warp_max = WarpReduce(temp_storage[warp_idx]).Reduce(max1, cub::Max());
    warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff);

    // Only the lowest lane holding the maximum records it. Without this, tied
    // lanes would all write the same slot and all discard their candidate.
    const unsigned int holders = __ballot_sync(0xffffffff, warp_max == max1);
    if(lane_idx == __ffs(holders) - 1)
    {
      smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1;
      smem_max_indices[warp_idx*8 + i] = max_idx1;

      sign1 = sign2;
      max1 = max2;
      max_idx1 = max_idx2;

      sign2 = -1;
      max2 = -64000.0f;
      max_idx2 = -1;
    }
    __syncwarp();
  }

  // 6. TODO: store the recorded values/indices to `out` / `out_idx`. An earlier
  // draft used offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8 for the first 8
  // lanes of each warp, but the output layout is not defined yet.
}

#define THREADS_ESTIMATE 512
#define NUM_ESTIMATE 8
#define BLOCK_ESTIMATE 4096

template<typename T>
__launch_bounds__(THREADS_ESTIMATE, 1)
__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n)
{
  // Estimates 256 quantiles of A and accumulates them into code[0..255].
  // Each BLOCK_ESTIMATE-element chunk is sorted. For q = 0..255, the element
  // at relative position offset + q*(1-2*offset)/255 is scaled by
  // 1/(number of full chunks) and atomically added to code[q]. code therefore
  // ends up with the per-chunk quantiles averaged over chunks (presumably the
  // caller zero-initializes code — confirm).
  // A partial trailing chunk is skipped unless n <= BLOCK_ESTIMATE; then the
  // missing items are padded with max_val.
  // Expects blockDim.x == THREADS_ESTIMATE: the cub collectives and the
  // vals[k/THREADS_ESTIMATE] lookup below rely on it.
  const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
  const int base_idx = (blockIdx.x * BLOCK_ESTIMATE);
  const float reciprocal_num_blocks = 1.0f/(n < BLOCK_ESTIMATE ? 1.0f : (n/BLOCK_ESTIMATE));

  T vals[NUM_ESTIMATE];

  typedef cub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
  typedef cub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

  __shared__ union {
      typename LoadFloat::TempStorage loadf;
      typename BlockRadixSort::TempStorage sort;
      int smem_qidx[BLOCK_ESTIMATE];
  } temp_storage;

  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE)
  {
      const int valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;

      // do not process half-blocks (block-uniform, so skipping barriers is safe)
      if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; }

      #pragma unroll 4
      for(int j = 0; j < NUM_ESTIMATE; j++)
          vals[j] = max_val;

      __syncthreads();
      LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items);

      #pragma unroll 4
      for(int j = 0; j < NUM_ESTIMATE; j++)
          vals[j] = ((float)vals[j]) * reciprocal_num_blocks;

      __syncthreads();
      // Sort into a striped layout: item j of thread t has rank
      // j*THREADS_ESTIMATE + t. With THREADS_ESTIMATE = 512, thread 0 holds
      // ranks [0, 512, 1024, ...] and thread 1 holds [1, 513, 1025, ...].
      BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals);

      __syncthreads();
      for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x)
          temp_storage.smem_qidx[j] = -1;

      // Every -1 must be written before any quantile slot is claimed.
      // Otherwise a late -1 from another thread can erase a claimed slot.
      __syncthreads();

      if(threadIdx.x < 256)
      {
          float q_interval = (1.0f-(2.0f*offset))/255.0f;
          int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1)));
          temp_storage.smem_qidx[local_idx] = threadIdx.x;
      }

      __syncthreads();

      for(int k = threadIdx.x; k < BLOCK_ESTIMATE; k+=blockDim.x)
      {
          const int q = temp_storage.smem_qidx[k];
          if(q != -1)
              atomicAdd(&code[q], vals[k/THREADS_ESTIMATE]);
      }
  }
}


// Quantizes n floats from A to 8-bit indices into the 256-entry code book
// `code`, rounding to the nearest entry with no scaling. Writes one byte per
// input to out. Expects TH threads per block. Each block handles NUM_BLOCK
// (= TH*NUM) consecutive inputs per iteration and advances by
// gridDim.x*NUM_BLOCK.
__launch_bounds__(TH, 4)
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
{
  typedef cub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  typedef cub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;

  __shared__ typename LoadFloat::TempStorage loadf;
  __shared__ typename StoreChar::TempStorage storec;
  __shared__ float smem_code[256];

  // Stage the code book. The barrier at the top of the loop makes it visible
  // before its first use.
  if(threadIdx.x < 256)
    smem_code[threadIdx.x] = code[threadIdx.x];

  // n rounded up to a whole number of NUM_BLOCK-sized tiles
  const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
  const unsigned int tile_stride = gridDim.x*NUM_BLOCK;

  float vals[NUM];
  unsigned char qvals[NUM];

  for (unsigned int tile = blockIdx.x * NUM_BLOCK; tile < n_full; tile += tile_stride)
  {
      // the last tile may be partial
      const int valid_items = n - tile > NUM_BLOCK ? NUM_BLOCK : n - tile;

      __syncthreads();
      LoadFloat(loadf).Load(&(A[tile]), vals, valid_items);

      #pragma unroll 4
      for(int j = 0; j < NUM; j++)
          qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);

      __syncthreads();
      StoreChar(storec).Store(&(out[tile]), qvals, valid_items);
  }
}

Tim Dettmers's avatar
Tim Dettmers committed
717
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
  // Blockwise quantization. Each BLOCK_SIZE-element block of A is processed
  // as follows:
  //   - its absolute maximum is written to absmax[block];
  //   - its values are multiplied by the reciprocal of that maximum;
  //   - the results are encoded according to DATA_TYPE:
  //       General8bit: one byte per value, the nearest (or stochastically
  //                    rounded) entry of the 256-entry `code` table;
  //       FP4 / NF4:   two 4-bit codes per byte, the even element in the
  //                    high nibble.
  // Expects BLOCK_SIZE/NUM_PER_TH threads per block, since the cub
  // collectives are sized for that. Each thread handles NUM_PER_TH
  // consecutive values.
  const int n_full = gridDim.x * BLOCK_SIZE;
  int valid_items = 0;
  const int base_idx = (blockIdx.x * BLOCK_SIZE);

  T vals[NUM_PER_TH];
  float rand_vals[NUM_PER_TH];
  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
  float local_abs_max = 0.0f;
  int local_rand_idx = 0;

  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

  __shared__ typename LoadT::TempStorage loadt;
  __shared__ typename LoadFloat::TempStorage loadf;
  __shared__ typename StoreChar::TempStorage storec;
  __shared__ typename BlockReduce::TempStorage reduce;
  __shared__ float smem_code[256];
  __shared__ float smem_absmax_value[1];

  // Only the 8-bit format needs the code book.
  if(DATA_TYPE == General8bit)
    for(int i = threadIdx.x; i < 256; i+=blockDim.x)
      smem_code[i] = code[i];

  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_abs_max = -FLT_MAX;

    __syncthreads();
    LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);

    // 1. compute the block's absolute maximum
    #pragma unroll NUM_PER_TH
    for(int j = 0; j < NUM_PER_TH; j++)
       local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    // NOTE(review): cub's num_valid argument counts threads, but valid_items
    // counts elements. The result is still correct, because padding elements
    // load as 0 and every valid element sits in a thread index below
    // valid_items.
    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);

    // 2. broadcast it. Thread 0 also stores it for dequantization.
    if(threadIdx.x == 0)
      smem_absmax_value[0] = local_abs_max;

    __syncthreads();

    if(threadIdx.x == 0)
      absmax[i/BLOCK_SIZE] = local_abs_max;
    else
      local_abs_max = smem_absmax_value[0];

    __syncwarp();

    // 3. normalize the inputs and quantize them
    local_abs_max = 1.0f/local_abs_max;

    if(STOCHASTIC)
    {
      // NOTE(review): NUM_BLOCK and NUM are the file-level kQuantize constants,
      // not this kernel's template parameters. The load base also varies with
      // threadIdx.x even though BlockLoad is a block-wide collective. It reads
      // up to BLOCK_SIZE floats past an index of at most 1019. Confirm that
      // `rand` is large enough and that this is intended.
      local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4);
      LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
    }

    switch(DATA_TYPE)
    {
        case General8bit:
            #pragma unroll NUM_PER_TH
            for(int j = 0; j < NUM_PER_TH; j++)
            {
                if(!STOCHASTIC)
                 qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
                else
                 qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
            }
            break;
        case FP4:
            #pragma unroll NUM_PER_TH
            for(int j = 0; j < NUM_PER_TH/2; j++)
            {
              // Build every byte from scratch. OR-ing into one byte shared
              // across iterations leaked earlier nibbles into later bytes.
              unsigned char packed_4bit = dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
              packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
              qvals[j] = packed_4bit;
            }
            break;
        case NF4:
            #pragma unroll NUM_PER_TH
            for(int j = 0; j < NUM_PER_TH/2; j++)
            {
              // high nibble = even element, low nibble = odd element
              unsigned char packed_4bit = dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
              packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
              qvals[j] = packed_4bit;
            }
            break;
    }

    __syncthreads();
    // 4-bit formats store half as many bytes; an odd trailing element gets its own byte.
    StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
  }
}

824
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n)
{
  // Inverse of kQuantizeBlockwise. Reads TILE_SIZE bytes of A per iteration.
  // Each byte is decoded according to DATA_TYPE:
  //   General8bit: one value, code[byte];
  //   FP4 / NF4:   two values, high nibble first.
  // Each value is multiplied by the absmax entry that covers the thread's
  // first byte: absmax[(byte offset)/blocksize]. For the 4-bit formats that
  // index is computed in bytes, so presumably the caller passes blocksize in
  // bytes (confirm).
  // Expects THREADS threads per block, each handling NUM_PER_TH bytes.
  constexpr int VALS_PER_BYTE = (DATA_TYPE > 0) ? 2 : 1;

  typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
  typedef cub::BlockStore<T, THREADS, NUM_PER_TH*VALS_PER_BYTE, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

  __shared__ typename LoadChar::TempStorage loadchar;
  __shared__ typename StoreT::TempStorage storet;

  unsigned char qvals[NUM_PER_TH];
  T vals[NUM_PER_TH*VALS_PER_BYTE];

  const int n_load = (gridDim.x * TILE_SIZE);
  for (unsigned int tile = blockIdx.x * TILE_SIZE; tile < n_load; tile += gridDim.x*TILE_SIZE)
  {
    int valid_items_load;
    int valid_items_store;
    if(DATA_TYPE > 0)
    {
      // bytes remaining vs. values remaining (two values per byte)
      valid_items_load = (n+1)/2 - tile > TILE_SIZE ? TILE_SIZE : (n+1)/2 - tile;
      valid_items_store = n - tile*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - tile*2;
    }
    else
    {
      valid_items_load = n - tile > TILE_SIZE ? TILE_SIZE : n - tile;
      valid_items_store = valid_items_load;
    }
    const float local_abs_max = __ldg(&absmax[(tile+threadIdx.x*NUM_PER_TH)/(blocksize)]);

    __syncthreads();
    LoadChar(loadchar).Load(&(A[tile]), qvals, valid_items_load, 128);

    if(DATA_TYPE == General8bit)
    {
      // code book lookup through the read-only data cache
      #pragma unroll NUM_PER_TH
      for(int j = 0; j < NUM_PER_TH; j++)
        vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
    }
    else if(DATA_TYPE == FP4)
    {
      #pragma unroll NUM_PER_TH
      for(int j = 0; j < NUM_PER_TH; j++)
      {
        const unsigned char byte = qvals[j];
        vals[2*j]     = dDequantizeFP4Tree(byte >> 4, local_abs_max);
        vals[2*j + 1] = dDequantizeFP4Tree(byte & 0x0F, local_abs_max);
      }
    }
    else if(DATA_TYPE == NF4)
    {
      #pragma unroll NUM_PER_TH
      for(int j = 0; j < NUM_PER_TH; j++)
      {
        const unsigned char byte = qvals[j];
        vals[2*j]     = dDequantizeNF4(byte >> 4)* local_abs_max;
        vals[2*j + 1] = dDequantizeNF4(byte & 0x0F)* local_abs_max;
      }
    }

    __syncthreads();
    StoreT(storet).Store(&(out[tile*VALS_PER_BYTE]), vals, valid_items_store);
  }
}

// Dequantizes n 8-bit indices from A through the 256-entry code book `code`:
// out[i] = code[A[i]]. A grid-stride loop lets any launch configuration cover
// all n elements.
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
{
	const unsigned int numThreads = blockDim.x * gridDim.x;
	const int idx = (blockIdx.x * blockDim.x) + threadIdx.x;

	// Stage the code book in shared memory with a block-stride loop. Every
	// entry is then filled even when the block has fewer than 256 threads;
	// previously entries >= blockDim.x were read uninitialized.
	__shared__ float smem_code[256];
	for(int k = threadIdx.x; k < 256; k += blockDim.x)
		smem_code[k] = code[k];

	__syncthreads();

	for (int i = idx;i < n; i += numThreads)
	{
		out[i] = smem_code[A[i]];
	}
}



template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
                float* state1, float* state2, float *unorm,
                const float beta1, const float beta2, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, const int n)
{
  // Pre-pass of the 32-bit two-state optimizer. When OPTIMIZER == ADAM, every element's
  // bias-corrected update m_hat/(sqrt(v_hat)+eps) is recomputed from (g, state1, state2),
  // squared, reduced across the block, and added to unorm[0] by thread 0 (one atomicAdd
  // per tile). Nothing is written back to g, p, state1 or state2; p, weight_decay and lr
  // are unused.
  // Layout: tiles of BLOCK_SIZE elements, NUM_VALS per thread. The cub primitives are sized
  // for BLOCK_SIZE/NUM_VALS threads, so blockDim.x is assumed to equal BLOCK_SIZE/NUM_VALS.
  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;

  // A single shared buffer backs all three collectives; __syncthreads() separates their uses.
  __shared__ union {
      typename Load::TempStorage load;
      typename LoadFloat::TempStorage loadf;
      typename BlockReduce::TempStorage reduce;
  } temp_storage;

  // Inverse Adam bias corrections 1/(1 - beta^step).
  const float inv_bias1 = 1.0f/(1.0f - powf(beta1, step));
  const float inv_bias2 = 1.0f/(1.0f - powf(beta2, step));

  // n rounded up to a whole number of tiles.
  const int tail = n % BLOCK_SIZE;
  const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (tail == 0 ? 0 : BLOCK_SIZE);
  const int first_tile = (blockIdx.x * blockDim.x * NUM_VALS);

  T grads[NUM_VALS];
  float m[NUM_VALS];
  float v[NUM_VALS];

  for (unsigned int start = first_tile; start < n_full; start += gridDim.x*BLOCK_SIZE)
  {
      // Elements of this tile that lie inside the tensor; the rest load as 0.0f.
      const int count = n - start >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - start;

      __syncthreads();
      Load(temp_storage.load).Load(&(g[start]), grads, count, 0.0f);
      __syncthreads();
      LoadFloat(temp_storage.loadf).Load(&(state1[start]), m, count, 0.0f);
      __syncthreads();
      LoadFloat(temp_storage.loadf).Load(&(state2[start]), v, count, 0.0f);

      // Gradient scaling is stored back into T before use, as in the update kernel.
      # pragma unroll NUM_VALS
      for(unsigned int j = 0; j < NUM_VALS; j++)
        grads[j] = gnorm_scale*((float)grads[j]);

      if (OPTIMIZER == ADAM)
      {
          # pragma unroll NUM_VALS
          for(unsigned int j = 0; j < NUM_VALS; j++)
          {
              const float gj = (float)grads[j];
              m[j] = m[j]*beta1 + ((1.0f -beta1)*gj);
              v[j] = v[j]*beta2 + ((1.0f -beta2)*(gj*gj));
              m[j] *= inv_bias1;
              v[j] *= inv_bias2;
              const float update = m[j]/(sqrtf(v[j])+eps);
              m[j] = update*update; // m[] now holds the squared update
          }
      }

      // Per-thread partial sum, accumulated in element order.
      float partial = m[0];
      # pragma unroll
      for(unsigned int j = 1; j < NUM_VALS; j++)
          partial += m[j];

      __syncthreads();
      partial = BlockReduce(temp_storage.reduce).Sum(partial);

      // Only thread 0 holds the valid block aggregate.
      if(threadIdx.x == 0)
        atomicAdd(&unorm[0], partial);

      __syncwarp();
  }
}



#define NUM_PER_THREAD 4

// Fused 32-bit two-state update; only OPTIMIZER == ADAM modifies anything. Per element:
//   g <- gnorm_scale*g;  m <- beta1*m + (1-beta1)*g;  v <- beta2*v + (1-beta2)*g^2;
//   p <- p + update_scale*step_size*m/(sqrt(v) + eps*sqrt(1-beta2^step)),
//        with step_size = -lr*sqrt(1-beta2^step)/(1-beta1^step);
//   then, if weight_decay > 0, p <- p*(1 - lr*weight_decay).
// With skip_zeros set, elements whose scaled gradient is exactly 0 keep p, m and v as loaded.
// update_scale is below 1 only when max_unorm > 0 and sqrt(unorm[0]) > max_unorm*param_norm.
// Layout: cub primitives sized for TH threads x NUM_PER_THREAD items, so blockDim.x is
// assumed to be TH; each block walks tiles of TH*NUM_PER_THREAD elements.
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit2State(T* g, T* p,
                float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                const float beta1, const float beta2, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
  constexpr int kTile = TH*NUM_PER_THREAD;

  // Bias corrections folded into a single step size.
  const float correction1 = 1.0f - powf(beta1, step);
  const float correction2 = sqrtf(1.0f - powf(beta2, step));
  const float step_size = -lr*correction2/correction1;

  // Update-norm clipping factor.
  float update_scale = 1.0f;
  if(max_unorm > 0.0f)
  {
    const float update_norm = sqrtf(unorm[0]);
    const float limit = max_unorm*param_norm;
    if(update_norm > limit){ update_scale = limit/update_norm; }
  }

  typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
  typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;

  typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;

  // Shared scratch reused by every load/store; uses are separated by __syncthreads().
  __shared__ union {
      typename Load::TempStorage load;
      typename Store::TempStorage store;
      typename LoadFloat::TempStorage loadf;
      typename StoreFloat::TempStorage storef;
  } temp_storage;

  T grads[NUM_PER_THREAD];
  T params[NUM_PER_THREAD];
  float m[NUM_PER_THREAD];
  float v[NUM_PER_THREAD];

  const int n_full = (kTile*(n/kTile)) + (n % kTile == 0 ? 0 : kTile);
  const int first_tile = (blockIdx.x * blockDim.x * NUM_PER_THREAD);

  for (unsigned int start = first_tile; start < n_full; start += gridDim.x*kTile)
  {
      // Number of in-bounds elements in this tile; loads/stores beyond it are masked by cub.
      const int count = n - start >= kTile ? kTile : n - start;

      __syncthreads();
      Load(temp_storage.load).Load(&(g[start]), grads, count);
      __syncthreads();
      LoadFloat(temp_storage.loadf).Load(&(state1[start]), m, count);
      __syncthreads();
      LoadFloat(temp_storage.loadf).Load(&(state2[start]), v, count);
      __syncthreads();
      Load(temp_storage.load).Load(&(p[start]), params, count);

      # pragma unroll 4
      for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
        grads[j] = gnorm_scale*((float)grads[j]);

      # pragma unroll 4
      for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
      {
          const float gj = (float)grads[j];
          const bool skipped = skip_zeros && gj == 0.0f;
          if(OPTIMIZER != ADAM || skipped)
            continue;

          m[j] = m[j]*beta1 + ((1.0f -beta1)*gj);
          v[j] = v[j]*beta2 + ((1.0f -beta2)*(gj*gj));
          params[j] = ((float)params[j]) + (update_scale*step_size*(m[j]/(sqrtf(v[j])+(eps*correction2))));

          // Decoupled (multiplicative) weight decay on the freshly updated parameter.
          if(weight_decay > 0.0f)
            params[j] = ((float)params[j])*(1.0f-(lr*weight_decay));
      }

      __syncthreads();
      Store(temp_storage.store).Store(&(p[start]), params, count);
      __syncthreads();
      StoreFloat(temp_storage.storef).Store(&(state1[start]), m, count);
      __syncthreads();
      StoreFloat(temp_storage.storef).Store(&(state2[start]), v, count);
  }
}

template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
                float* state1, float *unorm,
                const float beta1, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, const int n)
{
  // Pre-pass of the 32-bit one-state optimizers: recomputes each element's update from
  // (gnorm-scaled g, state1) and adds the block-reduced sum of squares to unorm[0].
  //   MOMENTUM: squares the new momentum buffer (seeded with g on step 1);
  //   RMSPROP / ADAGRAD: squares g/(sqrt(new state)+eps).
  // Nothing is written back; p, weight_decay and lr are unused.
  // Layout: tiles of BLOCK_SIZE elements, NUM_VALS per thread; blockDim.x is assumed to
  // equal BLOCK_SIZE/NUM_VALS (the size the cub primitives are built for).
  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;

  __shared__ union {
      typename Load::TempStorage load;
      typename LoadFloat::TempStorage loadf;
      typename BlockReduce::TempStorage reduce;
  } temp_storage;

  const int tail = n % BLOCK_SIZE;
  const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (tail == 0 ? 0 : BLOCK_SIZE);
  const int first_tile = (blockIdx.x * blockDim.x * NUM_VALS);

  T grads[NUM_VALS];
  float sq[NUM_VALS]; // loaded state, then overwritten with the squared update

  for (unsigned int start = first_tile; start < n_full; start += gridDim.x*BLOCK_SIZE)
  {
      const int count = n - start >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - start;

      __syncthreads();
      Load(temp_storage.load).Load(&(g[start]), grads, count, 0.0f);
      __syncthreads();
      LoadFloat(temp_storage.loadf).Load(&(state1[start]), sq, count, 0.0f);

      # pragma unroll NUM_VALS
      for(unsigned int j = 0; j < NUM_VALS; j++)
        grads[j] = gnorm_scale*((float)grads[j]);

      # pragma unroll NUM_VALS
      for(unsigned int j = 0; j < NUM_VALS; j++)
      {
          const float gj = (float)grads[j];
          if (OPTIMIZER == MOMENTUM)
          {
              const float mom = (step == 1) ? gj : sq[j]*beta1 + gj;
              sq[j] = mom*mom;
          }
          else if (OPTIMIZER == RMSPROP)
          {
              const float avg = sq[j]*beta1 + ((1.0f-beta1)*gj*gj);
              const float upd = __fdividef(gj,sqrtf(avg)+eps);
              sq[j] = upd*upd;
          }
          else if (OPTIMIZER == ADAGRAD)
          {
              const float acc = sq[j] + gj*gj;
              const float upd = __fdividef(gj,sqrtf(acc)+eps);
              sq[j] = upd*upd;
          }
      }

      float partial = sq[0];
      # pragma unroll
      for(unsigned int j = 1; j < NUM_VALS; j++)
        partial += sq[j];

      __syncthreads();
      partial = BlockReduce(temp_storage.reduce).Sum(partial, count);

      if(threadIdx.x == 0)
        atomicAdd(&unorm[0], partial);

      __syncwarp();
  }
}

template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
                float *state1, float *unorm, const float max_unorm, const float param_norm,
                const float beta1, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
  // Fused 32-bit one-state update. The gradient is scaled by gnorm_scale and, when
  // weight_decay > 0, p*weight_decay is added to it before the optimizer step:
  //   MOMENTUM: s <- g (step 1) or beta1*s + g;      p <- p - update_scale*lr*s
  //   RMSPROP:  s <- beta1*s + (1-beta1)*g^2;        p <- p - update_scale*lr*g/(sqrt(s)+eps)
  //   ADAGRAD:  s <- s + g^2;                        p <- p - lr*g/(sqrt(s)+eps)
  // NOTE(review): the ADAGRAD branch does not apply update_scale — confirm this is intended.
  // With skip_zeros set, elements whose (decayed, scaled) gradient is exactly 0 are left as loaded.
  // Layout: TH threads x NUM_PER_THREAD items per tile; blockDim.x is assumed to be TH.
  constexpr int kTile = TH*NUM_PER_THREAD;

  // Update-norm clipping factor (this kernel's threshold includes +eps).
  float update_scale = 1.0f;
  if(max_unorm > 0.0f)
  {
    const float update_norm = sqrtf(unorm[0]);
    const float limit = max_unorm*param_norm+eps;
    if(update_norm > limit){ update_scale = limit/update_norm; }
  }

  typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
  typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;

  typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;

  __shared__ union {
      typename Load::TempStorage load;
      typename Store::TempStorage store;
      typename LoadFloat::TempStorage loadf;
      typename StoreFloat::TempStorage storef;
  } temp_storage;

  T grads[NUM_PER_THREAD];
  T params[NUM_PER_THREAD];
  float state[NUM_PER_THREAD];

  const int n_full = (kTile*(n/kTile)) + (n % kTile == 0 ? 0 : kTile);
  const int first_tile = (blockIdx.x * blockDim.x * NUM_PER_THREAD);

  for (unsigned int start = first_tile; start < n_full; start += gridDim.x*kTile)
  {
      const int count = n - start >= kTile ? kTile : n - start;

      __syncthreads();
      Load(temp_storage.load).Load(&(g[start]), grads, count);
      __syncthreads();
      LoadFloat(temp_storage.loadf).Load(&(state1[start]), state, count);
      __syncthreads();
      Load(temp_storage.load).Load(&(p[start]), params, count);

      // Scale, then fold coupled weight decay into the gradient (kept in T between steps).
      # pragma unroll 4
      for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
      {
        grads[j] = gnorm_scale*((float)grads[j]);
        if(weight_decay > 0.0f)
          grads[j] = (float)grads[j] + (((float)params[j])*weight_decay);
      }

      # pragma unroll 4
      for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
      {
          const float gj = (float)grads[j];
          if(skip_zeros && gj == 0.0f)
            continue;

          if (OPTIMIZER == MOMENTUM)
          {
              state[j] = (step == 1) ? gj : state[j]*beta1 + gj;
              params[j] = ((float)params[j]) + update_scale*(-lr*(state[j]));
          }
          else if (OPTIMIZER == RMSPROP)
          {
              state[j] = state[j]*beta1 + ((1.0f-beta1)*gj*gj);
              params[j] = ((float)params[j]) - update_scale*(lr*__fdividef(gj,sqrtf((float)state[j])+eps));
          }
          else if (OPTIMIZER == ADAGRAD)
          {
              state[j] = state[j] + gj*gj;
              params[j] = ((float)params[j]) - lr*__fdividef(gj,sqrtf((float)state[j])+eps);
          }
      }

      __syncthreads();
      Store(temp_storage.store).Store(&(p[start]), params, count);
      __syncthreads();
      StoreFloat(temp_storage.storef).Store(&(state1[start]), state, count);
  }
}


#define NUM8BIT 16
#define NUM_THREADS 256
#define NUM_PER_BLOCK 4096

// Pre-pass for the static 8-bit two-state (Adam-style) optimizer.
// For every element it dequantizes state1/state2 (quantiles*max1[0] / quantiles*max2[0]),
// applies one EMA step with the gnorm-scaled gradient, and publishes via atomics:
//   new_max1[0], new_max2[0] <- running max of |state| (atomicMax, one per block),
//   unorm[0]                 += sum of squared bias-corrected updates (only if unorm != NULL).
// Nothing is written back to p, g, state1 or state2 (p and OPTIMIZER are unused).
// Layout: the cub primitives are sized for NUM_THREADS threads x NUM8BIT items, so blockDim.x
// must be NUM_THREADS and each block consumes tiles of NUM_THREADS*NUM8BIT (== NUM_PER_BLOCK)
// elements, striding by the whole grid. quantiles1/quantiles2 must hold 256 entries.
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__  const state1, unsigned char* __restrict__ const state2,
                float *unorm,
                const float beta1, const float beta2,
                const float eps, const int step,
                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                float* max1, float* max2, float* new_max1, float* new_max2,
                const float gnorm_scale, const int n)
{
    const int n_full = gridDim.x * NUM_PER_BLOCK;
    // Each block starts at its own tile. (Previously blockIdx.x*blockDim.x*NUM_PER_THREAD, i.e.
    // 1024-element offsets for 4096-element tiles: blocks overlapped, double-counting unorm,
    // and the end of the tensor was never visited.)
    const int base_idx = blockIdx.x * NUM_THREADS * NUM8BIT;
    // cub BlockLoad yields a blocked arrangement: this thread owns tile offsets
    // [thread_offset, thread_offset + NUM8BIT).
    const int thread_offset = threadIdx.x * NUM8BIT;
    int valid_items = 0;
    float local_max_s1 = -FLT_MAX;
    float local_max_s2 = -FLT_MAX;
    float local_unorm = 0.0f;

    // Loop-invariant factors, computed once instead of per element.
    const float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step));
    const float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step));
    const float max_s1 = max1[0];
    const float max_s2 = max2[0];

    T g_vals[NUM8BIT];
    unsigned char m_c1[NUM8BIT];
    unsigned char r_c2[NUM8BIT];

    typedef cub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
    typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadUInt8::TempStorage loadc;
        typename BlockReduce::TempStorage reduce;
    } temp_storage;

    __shared__ float smem_quantiles1[256];
    __shared__ float smem_quantiles2[256];

    if(threadIdx.x < 256)
    {
        smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
        smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x];
    }

    __syncthreads();

    for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT)
    {
        // Block-uniform exit for tiles that start past the end (over-provisioned grid).
        const int remaining = n - (int)i;
        if(remaining <= 0){ break; }
        valid_items = remaining < NUM_PER_BLOCK ? remaining : NUM_PER_BLOCK;

        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
        __syncthreads();
        LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128);
        __syncthreads();

        #pragma unroll 16
        for(int j = 0; j < NUM8BIT; j++)
        {
            // Tail tile: slots past the end of the tensor only hold load defaults and
            // must not influence the max / norm statistics.
            if(thread_offset + j >= valid_items){ continue; }

            const float g_val = gnorm_scale*float(g_vals[j]);

            float s1 = smem_quantiles1[m_c1[j]]*max_s1*beta1;
            s1 += (1.0f-beta1)*g_val;
            float s2 = smem_quantiles2[r_c2[j]]*max_s2*beta2;
            s2 += (1.0f-beta2)*g_val*g_val;

            local_max_s1 = fmaxf(local_max_s1, fabsf(s1));
            local_max_s2 = fmaxf(local_max_s2, fabsf(s2));

            if(unorm != NULL)
            {
                const float update_val = (s1*correction1)/(sqrtf(s2*correction2)+eps);
                local_unorm += update_val*update_val;
            }
        }
    }

    // Masked slots kept their neutral initial values (-FLT_MAX / 0), so the whole block can
    // take part in the reductions.
    __syncthreads();
    local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max());
    __syncthreads();
    local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max());
    if(unorm != NULL)
    {
      __syncthreads();
      local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm);
    }

    // Aggregates are only valid in thread 0.
    if(threadIdx.x == 0)
    {
        atomicMax(&new_max1[0], local_max_s1);
        atomicMax(&new_max2[0], local_max_s2);
        if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); }
    }
}

#define NUM_PER_THREAD2 4
#define NUM_THREADS2 1024
#define NUM_PER_BLOCK2 4096

// Static 8-bit two-state (Adam-style) update.
// Per element: dequantize state1/state2 (quantiles*max1[0] / quantiles*max2[0]), apply the EMA
// step with the gnorm-scaled gradient, requantize against new_max1[0]/new_max2[0] (assumed to be
// filled beforehand, e.g. by a pre-pass — confirm with the host wrapper), then
//   p <- p + update_scale*step_size*s1/(sqrt(s2) + sqrt(1-beta2^step)*eps),
//   step_size = -lr*sqrt(1-beta2^step)/(1-beta1^step),
// followed by p <- p*(1 - lr*weight_decay) when weight_decay > 0.
// update_scale < 1 only when max_unorm > 0 and sqrt(unorm[0]) > max_unorm*param_norm.
// Layout: NUM_THREADS2 threads x NUM_PER_THREAD2 items (tiles of NUM_PER_BLOCK2 elements);
// blockDim.x must be NUM_THREADS2 (threads 0..511 fill the 256-entry quantile tables).
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS2, 1)
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
                const float *unorm, const float max_unorm, const float param_norm,
                const float beta1, const float beta2,
                const float eps, const int step, const float lr,
                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                float* max1, float* max2, float* new_max1, float* new_max2,
                float weight_decay,
                const float gnorm_scale, const int n)
{

    const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
    const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
    int valid_items = 0;
    float g_val = 0.0f;
    float s1_vals[NUM_PER_THREAD2];
    float s2_vals[NUM_PER_THREAD2];
    // Adam bias corrections folded into one step size; eps is rescaled by correction2 below.
    const float correction1 = 1.0f - powf(beta1, step);
    const float correction2 = sqrtf(1.0f - powf(beta2, step));
    const float step_size = -lr*correction2/correction1;
    // Reciprocals of the new maxima, used to normalise the states before requantization.
    float new_max_val1 = 1.0f/new_max1[0];
    float new_max_val2 = 1.0f/new_max2[0];

    // Update-norm clipping factor.
    float update_scale = 1.0f;
    if(max_unorm > 0.0f)
    {
      const float update_norm = sqrtf(unorm[0]);
      if(update_norm > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_norm; }
    }

    unsigned char c1s[NUM_PER_THREAD2];
    unsigned char c2s[NUM_PER_THREAD2];
    T p_vals[NUM_PER_THREAD2];
    T g_vals[NUM_PER_THREAD2];
    typedef cub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;

    typedef cub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
    typedef cub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

    __shared__ float smem_quantiles1[256];
    __shared__ float smem_quantiles2[256];

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadChar::TempStorage loadc;
        typename StoreChar::TempStorage storec;
        typename StoreT::TempStorage storeh;
    } temp_storage;

    if(threadIdx.x < 512)
    {
        if(threadIdx.x < 256)
            smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
        else
            smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256];
    }

    __syncthreads();

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
    {
        // Block-uniform exit for tiles that start past the end (over-provisioned grid).
        const int remaining = n - (int)i;
        if(remaining <= 0){ break; }
        valid_items = remaining < NUM_PER_BLOCK2 ? remaining : NUM_PER_BLOCK2;

        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
        __syncthreads();
        // Zero default so padded slots never hold uninitialised values.
        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);

        // No per-thread early exit here: every thread must reach the collective stores and the
        // __syncthreads() calls below. (A thread-divergent `continue` used to skip them, which is
        // undefined behaviour and left tail elements un-updated when n % NUM_PER_THREAD2 != 0.)
        // Slots past valid_items are computed but never stored.

        # pragma unroll 4
        for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
        {
            g_val = float(g_vals[j]);
            g_val *= gnorm_scale;
            s1_vals[j] = smem_quantiles1[c1s[j]];
            s1_vals[j] = s1_vals[j]*max1[0];

            s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));

            c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);

            // make sure state1 term has still the same sign after quantization
            // (not needed for state2 term which has only positive values)
            if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
            {
              if(s1_vals[j] > 0.0f)
                  c1s[j] += 1;
              else
                  c1s[j] -= 1;
            }

            s2_vals[j] = smem_quantiles2[c2s[j]];
            s2_vals[j] = s2_vals[j]*max2[0];
            s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
            c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2);
        }

        # pragma unroll 4
        for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
        {
            p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
            // Decoupled weight decay on the parameter itself, as in kOptimizer32bit2State.
            // update_scale clips only the optimizer update; it used to multiply the whole
            // parameter here, shrinking the weights whenever clipping was active.
            if(weight_decay > 0.0f)
                p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
        }

        __syncthreads();
        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
        __syncthreads();
    }
}


template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__  const state1,
                float *unorm,
                const float beta1,
                const float eps, const int step,
                float* __restrict__ const quantiles1,
                float* max1, float* new_max1,
                const float weight_decay,
                const float gnorm_scale, const int n)
{
    // Pre-pass for the static 8-bit one-state optimizers.
    // Dequantizes state1 (quantiles1*max1[0]) and applies one state step, then:
    //   new_max1[0] <- running max of |state1| (atomicMax, one per block);
    //   unorm[0]    += sum of squared momentum values (MOMENTUM only, when unorm != NULL).
    // Nothing is written back; p, eps and weight_decay are unused.
    // NOTE(review): MOMENTUM uses the raw g (no gnorm_scale) while RMSPROP uses the scaled one,
    // and RMSPROP adds nothing to unorm — confirm both match the corresponding update kernel.
    // Layout: NUM_THREADS threads x NUM8BIT items per tile (NUM_THREADS*NUM8BIT elements);
    // blockDim.x must be NUM_THREADS. quantiles1 must hold 256 entries.
    const int n_full = gridDim.x * NUM_PER_BLOCK;
    // Each block starts at its own tile. (Previously blockIdx.x*blockDim.x*NUM_PER_THREAD, which
    // made blocks overlap and left the end of the tensor unvisited.)
    const int base_idx = blockIdx.x * NUM_THREADS * NUM8BIT;
    // cub BlockLoad yields a blocked arrangement: this thread owns tile offsets
    // [thread_offset, thread_offset + NUM8BIT).
    const int thread_offset = threadIdx.x * NUM8BIT;
    int valid_items = 0;
    float g_val = 0.0f;
    float local_max_s1 = -FLT_MAX;
    float local_unorm = 0.0f;
    const float max_s1 = max1[0]; // loop invariant

    T g_vals[NUM8BIT];
    unsigned char m_c1[NUM8BIT];

    typedef cub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
    typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadUInt8::TempStorage loadc;
        typename BlockReduce::TempStorage reduce;
    } temp_storage;

    __shared__ float smem_quantiles1[256];

    if(threadIdx.x < 256)
      smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];

    __syncthreads();

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT)
    {
        // Block-uniform exit for tiles that start past the end (over-provisioned grid).
        const int remaining = n - (int)i;
        if(remaining <= 0){ break; }
        valid_items = remaining < NUM_PER_BLOCK ? remaining : NUM_PER_BLOCK;

        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);

        #pragma unroll 16
        for(int j = 0; j < NUM8BIT; j++)
        {
            // Tail tile: padded slots must not influence the max / norm statistics.
            if(thread_offset + j >= valid_items){ continue; }

            g_val = g_vals[j];
            g_val *= gnorm_scale;
            float s1 = smem_quantiles1[m_c1[j]]*max_s1;
            switch(OPTIMIZER)
            {
                case MOMENTUM:
                    if(step == 1)
                      s1 = (float)g_vals[j];
                    else
                      s1 = s1*beta1 + ((float)g_vals[j]);
                    if(unorm != NULL)
                      local_unorm += s1*s1;
                    break;
                case RMSPROP:
                    s1 = s1*beta1 + ((1.0f-beta1)*(g_val*g_val));
                    break;
            }

            local_max_s1 = fmaxf(local_max_s1, fabsf(s1));
        }
    }

    // Masked slots kept their neutral initial values, so the full block can reduce.
    __syncthreads();
    local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max());
    if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); }
    if(unorm != NULL)
    {
      __syncthreads();
      local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm);
      if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); }
    }

}

// Static (tensor-wide scale) 8-bit update for one-state optimizers
// (MOMENTUM and RMSPROP).
//
// state1 holds 8-bit codes into the 256-entry code book `quantiles1`; a code c
// dequantizes to quantiles1[c] * max1[0]. After the update the state is
// re-quantized against 1/new_max1[0] (new_max1 is presumably produced by the
// preceding preconditioning kernel -- confirm with the host launcher).
//
// Parameters:
//   p             parameters, updated in place
//   g             gradients
//   state1        8-bit state codes, updated in place
//   unorm         unorm[0] is read as a squared update norm when max_unorm > 0
//   max_unorm     update-norm clipping factor; <= 0 disables clipping
//   param_norm    parameter norm the clipping threshold is relative to
//   beta1         momentum / RMS decay factor
//   eps           RMSPROP denominator epsilon
//   step          optimizer step; step == 1 initializes the momentum buffer
//   lr            learning rate
//   quantiles1    256-entry code book for state1
//   max1          old state scale (max1[0])
//   new_max1      new state scale (new_max1[0])
//   weight_decay  L2 term added to the gradient (used by the RMSPROP path)
//   gnorm_scale   gradient scale factor (used by the RMSPROP path)
//   n             number of elements
//
// Launch layout: block b handles elements starting at b*blockDim.x*NUM_PER_THREAD2;
// the cub primitives are instantiated for NUM_THREADS2 threads, so blockDim.x is
// assumed to equal NUM_THREADS2 and the grid is assumed to cover n.
template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                const float *unorm, const float max_unorm, const float param_norm,
                const float beta1,
                const float eps, const int step, const float lr,
                float* __restrict__ const quantiles1,
                float* max1, float* new_max1,
                float weight_decay,
                const float gnorm_scale, const int n)
{
    // Elements processed by one block per iteration: the tile shape the cub
    // load/store primitives below are instantiated for.
    const int tile_items = NUM_THREADS2*NUM_PER_THREAD2;
    const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
    const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
    int valid_items = 0;
    float g_val = 0.0f;
    float s1_vals[NUM_PER_THREAD2];
    const float old_max1 = max1[0];
    const float new_max_val1 = 1.0f/new_max1[0];

    // Update-norm clipping (applied to the MOMENTUM step): shrink the step so
    // its norm does not exceed max_unorm*param_norm. Disabled when max_unorm <= 0.
    float update_scale = 1.0f;
    if(max_unorm > 0.0f)
    {
      const float update_norm = sqrtf(unorm[0]);
      if(update_norm > max_unorm*param_norm)
        update_scale = (max_unorm*param_norm)/update_norm;
    }

    unsigned char c1s[NUM_PER_THREAD2];
    T p_vals[NUM_PER_THREAD2];
    T g_vals[NUM_PER_THREAD2];
    typedef cub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;

    typedef cub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
    typedef cub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

    __shared__ float smem_quantiles1[256];

    // All load/store primitives share one buffer; every reuse below is
    // separated by a __syncthreads().
    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadChar::TempStorage loadc;
        typename StoreChar::TempStorage storec;
        typename StoreT::TempStorage storeh;
    } temp_storage;

    // Stage the code book in shared memory (strided, so any block size fills it).
    for(int q = threadIdx.x; q < 256; q += blockDim.x)
        smem_quantiles1[q] = quantiles1[q];

    __syncthreads();

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
    {
        // Signed arithmetic: a tile starting at or past n yields valid_items <= 0
        // (nothing loaded or stored) instead of an unsigned wrap-around.
        const int remaining = n - (int)i;
        valid_items = remaining >= tile_items ? tile_items : remaining;

        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);

        // Every thread runs the update, including threads whose items straddle or
        // lie past n: the guarded stores below only write items < n, and all
        // threads must reach the __syncthreads() calls that follow.
        # pragma unroll 4
        for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
        {
            g_val = float(g_vals[j]);
            g_val *= gnorm_scale;
            if(weight_decay > 0.0f)
              g_val += ((float)p_vals[j])*weight_decay;
            s1_vals[j] = smem_quantiles1[c1s[j]]*old_max1;

            switch(OPTIMIZER)
            {
                case MOMENTUM:
                  // NOTE(review): momentum uses the raw gradient (no gnorm_scale /
                  // weight decay), as does the momentum branch of the
                  // preconditioning kernel above; change both together if at all.
                  if(step == 1)
                    s1_vals[j] = (float)g_vals[j];
                  else
                    s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);

                  p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
                  break;
                case RMSPROP:
                  s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                  p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
                  break;
            }

            c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);

            // make sure state1 term has still the same sign after quantization
            if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
            {
              if(s1_vals[j] > 0.0f)
                  c1s[j] += 1;
              else
                  c1s[j] -= 1;
            }
        }

        // The store reuses temp_storage that the p load above just used.
        __syncthreads();
        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
        __syncthreads();
    }
}


// Accumulates the squared L2 norm of g into a 100-entry history buffer.
//
// Each block reduces tiles of BLOCK_SIZE consecutive elements
// (BLOCK_SIZE/NUM_VALS threads x NUM_VALS items), and thread 0 atomically adds
// each tile's sum to gnorm_vec[step % 100]. On step == 1 the sum is added to all
// 100 entries so the history starts out filled. gnorm_vec is only accumulated
// into, so the caller presumably zeroes the target slot(s) first -- confirm with
// the host code.
//
// Launch: blockDim.x == BLOCK_SIZE/NUM_VALS. Blocks advance through the input
// in strides of gridDim.x*BLOCK_SIZE, so any grid size covers all of n.
template<typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n)
{
  // n rounded up to a whole number of BLOCK_SIZE tiles.
  const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
  int valid_items = 0;

  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;

  __shared__ typename BlockReduce::TempStorage reduce;
  __shared__ typename LoadT::TempStorage loadT;
  T vals[NUM_VALS];
  float local_sum = 0.0f;

  for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_sum = 0.0f;

    // Guards reuse of both temp storages across loop iterations.
    __syncthreads();
    LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f);

    #pragma unroll NUM_VALS
    for(int j = 0; j < NUM_VALS; j++)
      local_sum += ((float)vals[j])*((float)vals[j]);

    // Out-of-range items were loaded as 0, so every thread's partial sum is
    // valid and a full-block reduction is exact. (Sum's num_valid overload
    // counts valid *threads*, not items, so valid_items is not its argument.)
    local_sum = BlockReduce(reduce).Sum(local_sum);
    if(threadIdx.x == 0)
    {
      if(step == 1)
      {
        // initialize with the same norm for all positions
        for(int j = 0; j < 100; j++)
          atomicAdd(&gnorm_vec[j], local_sum);
      }
      else
        atomicAdd(&gnorm_vec[step % 100], local_sum);
    }
  }
}


// Number of replicated copies of each 256-entry code book kept in shared
// memory; a thread reads copy `threadIdx.x % LANES` (presumably to spread
// shared-memory bank accesses -- confirm).
#define LANES 2
// Number of quadrant boundaries cached per thread from a code book: with
// QUAD == 3 these are the entries at indices 63, 127 and 191, which are passed
// to quantize_2D.
#define QUAD 3
// Blockwise 8-bit update with two optimizer states. The rule applied below is
// Adam with bias correction and decoupled weight decay (p *= 1 - lr*weight_decay
// after the step). NOTE(review): OPTIMIZER and skip_zeros are not consulted.
//
// state1 (first moment, signed; sign-corrected after quantization) and state2
// (second moment, non-negative) are 8-bit codes into the 256-entry code books
// quantiles1 / quantiles2. Each BLOCK_SIZE chunk of elements has its own scales
// absmax1[chunk] / absmax2[chunk], which this kernel overwrites with the chunk's
// new absolute maxima. Elements whose gradient is NaN/Inf get both states reset
// to 0 and their parameter left unchanged.
//
// Launch: one block of BLOCK_SIZE/N_PER_TH threads (<= 256) per BLOCK_SIZE
// elements; the grid is assumed to cover n (n_full = gridDim.x*BLOCK_SIZE).
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
                const float beta1, const float beta2,
                const float eps, const int step, const float lr,
                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                float* absmax1, float* absmax2,
                float weight_decay,
                const float gnorm_scale, const bool skip_zeros, const int n)
{
    const int n_full = gridDim.x * BLOCK_SIZE;
    const int base_idx = (blockIdx.x * BLOCK_SIZE);
    int valid_items = 0;
    float g_val = 0.0f;
    float s1_vals[N_PER_TH];
    float s2_vals[N_PER_TH];
    // Bias corrections folded into one step size: the step below is
    // step_size * m / (sqrt(v) + correction2*eps), with
    // step_size = -lr * sqrt(1-beta2^step) / (1-beta1^step).
    const float correction1 = 1.0f - __powf(beta1, step);
    const float correction2 = sqrtf(1.0f -__powf(beta2, step));
    const float step_size = __fdividef(-lr*correction2,correction1);
    const int lane_id = threadIdx.x % LANES;
    float new_local_abs_max1 = -FLT_MAX;
    float new_local_abs_max2 = -FLT_MAX;
    float quadrants1[QUAD];
    float quadrants2[QUAD];

    unsigned char c1s[N_PER_TH];
    unsigned char c2s[N_PER_TH];
    T g_vals[N_PER_TH];
    T p_vals[N_PER_TH];
    typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;

    typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
    typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

    __shared__ float smem_quantiles1[LANES][257];
    __shared__ float smem_quantiles2[LANES][257];
    typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
    typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
    __shared__ typename BlockReduce1::TempStorage reduce1;
    __shared__ typename BlockReduce2::TempStorage reduce2;
    __shared__ float smem_exchange1[1];
    __shared__ float smem_exchange2[1];

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadChar::TempStorage loadc;
        typename StoreChar::TempStorage storec;
        typename StoreT::TempStorage storeh;
    } temp_storage;

    // Stage LANES copies of both code books in shared memory. Strided so that
    // every one of the 256 entries is filled even with fewer than 256 threads;
    // each thread only copies entries it wrote itself, so no barrier is needed
    // between the two writes.
    for(int q = threadIdx.x; q < 256; q += blockDim.x)
    {
      smem_quantiles1[0][q] = quantiles1[q];
      smem_quantiles2[0][q] = quantiles2[q];
      # pragma unroll
      for(int lane = 1; lane < LANES; lane++)
      {
        smem_quantiles1[lane][q] = smem_quantiles1[0][q];
        smem_quantiles2[lane][q] = smem_quantiles2[0][q];
      }
    }

    __syncthreads();

    // Quadrant boundary values (indices 63, 127, 191 for QUAD == 3) for quantize_2D.
    #pragma unroll
    for(int k = 0; k < QUAD; k++)
    {
      quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
      quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
    }

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
    {
        valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);

        // Old chunk scales; read here, before thread 0 overwrites them below.
        const float old_absmax1 = absmax1[i/BLOCK_SIZE];
        const float old_absmax2 = absmax2[i/BLOCK_SIZE];

        new_local_abs_max1 = -FLT_MAX;
        new_local_abs_max2 = -FLT_MAX;

        // 1) Dequantize and update both moments, tracking the new chunk maxima.
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
            {
              s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*old_absmax2;
              g_val = g_vals[j];
              g_val *= gnorm_scale;
              s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));

              s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*old_absmax1;
              s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
            }
            else
            {
              // non-finite gradient: reset both states for this element
              s1_vals[j] = 0.0f;
              s2_vals[j] = 0.0f;
            }

            new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
            new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));
        }

        // 2) Block-wide maxima: the reduction result is valid in thread 0, which
        //    publishes it through shared memory and stores it as the new scales.
        new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
        new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());

        if(threadIdx.x == 0)
        {
          smem_exchange1[0] = new_local_abs_max1;
          smem_exchange2[0] = new_local_abs_max2;
        }

        __syncthreads();

        if(threadIdx.x == 0)
        {
          absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
          absmax2[i/BLOCK_SIZE] = new_local_abs_max2;
        }
        else
        {
          new_local_abs_max1 = smem_exchange1[0];
          new_local_abs_max2 = smem_exchange2[0];
        }

        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);

        // 3) Parameter step, then decoupled weight decay; skipped for
        //    non-finite gradients.
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
            {
              p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
              if(weight_decay > 0.0f)
                  p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
            }
        }

        __syncthreads();
        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);

        // 4) Re-quantize both states against the new chunk scales.
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
            c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2));

            // make sure state1 term has still the same sign after quantization
            // (not needed for state2 term which has only positive values)
            if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
            {
              if(s1_vals[j] > 0.0f)
                  c1s[j] += 1;
              else
                  c1s[j] -= 1;
            }
        }

        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
    }
}


// Repeated definition of LANES/QUAD for the kernels that follow. A macro may
// only be redefined with a token-identical replacement list, so these must stay
// in sync with the earlier LANES/QUAD definitions in this file.
#define LANES 2
#define QUAD 3
// Blockwise 8-bit update for one-state optimizers (MOMENTUM, RMSPROP, ADAGRAD).
//
// state1 holds 8-bit codes into the 256-entry code book `quantiles1`; each
// BLOCK_SIZE chunk of elements has its own scale absmax1[chunk], which this
// kernel overwrites with the chunk's new absolute maximum.
//
// The effective gradient is g*gnorm_scale (+ weight_decay*p when
// weight_decay > 0). It is used for both the state update and the parameter step.
// When skip_zeros is set, elements whose raw gradient is exactly 0 keep their
// parameter and their state value (re-quantized against the new chunk scale).
// beta2 is not used by this kernel.
//
// Launch: one block of BLOCK_SIZE/N_PER_TH threads (<= 256) per BLOCK_SIZE
// elements; the grid is assumed to cover n (n_full = gridDim.x*BLOCK_SIZE).
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
                const float beta1, const float beta2,
                const float eps, const int step, const float lr,
                float* __restrict__ const quantiles1,
                float* absmax1,
                float weight_decay,
                const float gnorm_scale, const bool skip_zeros, const int n)
{
    const int n_full = gridDim.x * BLOCK_SIZE;
    const int base_idx = (blockIdx.x * BLOCK_SIZE);
    int valid_items = 0;
    float g_val = 0.0f;
    float s1_vals[N_PER_TH];
    const int lane_id = threadIdx.x % LANES;
    float new_local_abs_max1 = -FLT_MAX;
    float quadrants1[QUAD];

    unsigned char c1s[N_PER_TH];
    T g_vals[N_PER_TH];
    T p_vals[N_PER_TH];

    typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;

    typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
    typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

    __shared__ float smem_quantiles1[LANES][257];
    typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
    __shared__ typename BlockReduce1::TempStorage reduce1;
    __shared__ float smem_exchange1[1];

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadChar::TempStorage loadc;
        typename StoreChar::TempStorage storec;
        typename StoreT::TempStorage storeh;
    } temp_storage;

    // Stage LANES copies of the code book in shared memory. Strided so that
    // all 256 entries are filled even with fewer than 256 threads.
    for(int q = threadIdx.x; q < 256; q += blockDim.x)
    {
        smem_quantiles1[0][q] = quantiles1[q];
        # pragma unroll
        for(int lane = 1; lane < LANES; lane++)
            smem_quantiles1[lane][q] = smem_quantiles1[0][q];
    }

    __syncthreads();

    // Quadrant boundary values (indices 63, 127, 191 for QUAD == 3) for quantize_2D.
    #pragma unroll
    for(int k = 0; k < QUAD; k++)
      quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
    {
        valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);

        // Old chunk scale; read here, before thread 0 overwrites it below.
        const float old_absmax1 = absmax1[i/BLOCK_SIZE];
        new_local_abs_max1 = -FLT_MAX;

        // 1) Dequantize and update the state, tracking the new chunk maximum.
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            // Dequantize unconditionally: an element skipped by skip_zeros must
            // still carry a defined state value into the max-reduction and the
            // re-quantization below.
            s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*old_absmax1;

            if(!skip_zeros || ((float)g_vals[j] != 0.0f))
            {
                g_val = ((float)g_vals[j])*gnorm_scale;
                if(weight_decay > 0.0f)
                    g_val += ((float)p_vals[j])*weight_decay;

                switch(OPTIMIZER)
                {
                    case MOMENTUM:
                        if(step == 1)
                            s1_vals[j] = g_val;
                        else
                            s1_vals[j] = (s1_vals[j]*beta1) + g_val;
                        break;
                    case RMSPROP:
                        s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                        break;
                    case ADAGRAD:
                        s1_vals[j] = s1_vals[j] + (g_val*g_val);
                        break;
                }
            }

            new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
        }

        // 2) Block-wide max: valid in thread 0, which publishes it through shared
        //    memory and stores it as the chunk's new scale.
        new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());

        if(threadIdx.x == 0)
          smem_exchange1[0] = new_local_abs_max1;

        __syncthreads();

        if(threadIdx.x == 0)
          absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
        else
          new_local_abs_max1 = smem_exchange1[0];

        // 3) Parameter step. RMSPROP/ADAGRAD use the same effective gradient as
        //    the state update (p_vals still holds the pre-step value here).
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            if(!skip_zeros || ((float)g_vals[j] != 0.0f))
            {
                switch(OPTIMIZER)
                {
                    case MOMENTUM:
                        p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
                        break;
                    case RMSPROP:
                    case ADAGRAD:
                        g_val = ((float)g_vals[j])*gnorm_scale;
                        if(weight_decay > 0.0f)
                            g_val += ((float)p_vals[j])*weight_decay;
                        p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
                        break;
                }
            }
        }

        __syncthreads();
        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);

        // 4) Re-quantize the state against the new chunk scale.
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));

            // make sure state1 term has still the same sign after quantization
            if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
            {
              if(s1_vals[j] > 0.0f)
                  c1s[j] += 1;
              else
                  c1s[j] -= 1;
            }
        }

        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
    }
}

// Per-tile row and column absolute-maximum statistics for a half matrix A with
// row stride `cols` (rows x cols, indexed as A[base_idx + row*cols]).
//
// Each block covers TILE_ROWS rows x TILE_COLS columns. THREADS threads load
// ITEMS_PER_THREAD consecutive columns each per row (TILE_COLS is assumed to
// equal THREADS*ITEMS_PER_THREAD -- TODO confirm). Results are merged into
// rowStats/colStats with atomicMax, so both must be initialized by the caller.
//
// With SPARSE_DECOMP, entries with |a| >= nnz_threshold are excluded from the
// maxima and counted per row. Block b writes its TILE_ROWS counts to
// nnz_count_row[b*TILE_ROWS + 1 + r] (offset by one, presumably for a later
// prefix sum).
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
{
  // 0. reset stats
  // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
  // 2. compute col max (per thread, kept in registers)
  // 3. compute row max (per block); store in smem to batch the global writes
  // 4. store data via atomicMax

  // each block loads TILE_COLS columns and TILE_ROWS rows
  // after reading a tile the row counter increases by TILE_ROWS
  // the col counter resets after reading TILE_COLS elements
  const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
  // col increases by TILE_COLS for each block and wraps back to 0 after tiledCols is reached
  const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
  const int base_idx = (base_row*cols) + base_col;
  const int items_per_load = ITEMS_PER_THREAD*THREADS;

  typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
  typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
  typedef cub::BlockReduce<int, THREADS> BlockRowSum;
  typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;

  __shared__ union {
    typename BlockExchange::TempStorage exchange;
    typename BlockRowReduce::TempStorage rowreduce;
    typename BlockRowSum::TempStorage rowsum;
    typename LoadT::TempStorage loadt;
  } temp_storage;

  __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
  __shared__ int smem_row_nnz_values[TILE_ROWS];

  half local_data[ITEMS_PER_THREAD];
  float local_data_fp32[ITEMS_PER_THREAD];
  float local_col_absmax_values[ITEMS_PER_THREAD];
  int local_row_nnz_count = 0;
  float row_absmax = -FLT_MAX;

  // 0. reset stats
  for(int j = 0; j < ITEMS_PER_THREAD; j++)
    smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;

  // smem_row_nnz_values has only TILE_ROWS entries; it must not be indexed with
  // the ITEMS_PER_THREAD*THREADS stride used above, which would overrun it.
  for(int r = threadIdx.x; r < TILE_ROWS; r += THREADS)
    smem_row_nnz_values[r] = 0;

  #pragma unroll ITEMS_PER_THREAD
  for(int j = 0; j < ITEMS_PER_THREAD; j++)
    local_col_absmax_values[j] = -FLT_MAX;

  __syncthreads();

  int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
  int i = base_idx;
  // 1. load row-by-row from the base position
  for(int row = 0; row < TILE_ROWS; row++)
  {
    // base_row and row are block-uniform, so this exit never makes the
    // barriers in the loop body divergent
    if(base_row+row >= rows){ break; }
    local_row_nnz_count = 0;
    i = base_idx + ((row)*cols);
    // each thread gets data from the same columns in every row
    __syncthreads();
    LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));

    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
      local_data[j] = fabsf(local_data[j]);

    if(SPARSE_DECOMP)
    {
      #pragma unroll ITEMS_PER_THREAD
      for(int j = 0; j < ITEMS_PER_THREAD; j++)
      {
        if((float)local_data[j] >= nnz_threshold)
        {
          local_row_nnz_count += 1;
          local_data[j] = 0.0f;
        }
      }
    }

    // 2. col max for this row (per thread)
    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
      local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));

    // 3. row max (per block)
    // fp32 copy: uses extra registers, but stays compatible with Kepler and
    // Maxwell (no fp16 units)
    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
      local_data_fp32[j] = local_data[j];

    __syncthreads();

    row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
    if(SPARSE_DECOMP)
    {
      __syncthreads();
      local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
    }
    // Reductions are valid in thread 0 only; keep the per-row results in shared
    // memory so the global atomics can be issued by the whole block at the end.
    if(threadIdx.x == 0)
    {
      smem_row_absmax_values[row] = row_absmax;
      smem_row_nnz_values[row] = local_row_nnz_count;
    }

    __syncthreads();
  }

  // 4. store data via atomicMax
  // rewrite the blocked column maxima [0, 1, 2, 3, ...] for t0 into a striped
  // arrangement [0, THREADS, 2*THREADS, ...] so the global accesses are coalesced
  __syncthreads();
  BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);

  #pragma unroll ITEMS_PER_THREAD
  for(int j = 0; j < ITEMS_PER_THREAD; j++)
  {
    const int col = base_col + threadIdx.x + (j*THREADS);
    if(col < cols)
    {
      float val = colStats[col];
      if(val < local_col_absmax_values[j])
        atomicMax(&colStats[col], local_col_absmax_values[j]);
    }
  }

  for(int j = 0; j < ITEMS_PER_THREAD; j++)
  {
    const int r = threadIdx.x + (j*THREADS);
    if(base_row + r < rows)
    {
      float val = rowStats[base_row + r];
      if(val < smem_row_absmax_values[r])
        atomicMax(&rowStats[base_row + r], smem_row_absmax_values[r]);
    }
  }

  if(SPARSE_DECOMP)
    for(int r = threadIdx.x; r < TILE_ROWS; r += THREADS)
      nnz_count_row[blockIdx.x*TILE_ROWS+r+1] = smem_row_nnz_values[r];
}

// Explicit instantiations for half input: 64 threads x 4 items per thread,
// 16-row x 256-column tiles, without (0) and with (1) sparse decomposition.
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);

// 1/(127*127) = 1/16129; presumably rescales int32 products of two operands
// quantized to the [-127, 127] range -- confirm with the dequant kernels.
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)

// Dequantizes an int32 matmul result into a row-major fp16 matrix `out` (numRows x numCols):
//
//   out[row, col] = A_int32 * MM_DEQUANT_CONST * rowStats[row] * colStats[col] + bias[col]
//
// Input layout: A is stored as consecutive numRows x 32 column tiles ("32 column-tile major").
// Element (row, col) is at (col/32)*(numRows*32) + row*32 + col%32.
//
// Launch layout: blockDim.x must equal THREADS. Each 32-column tile is split into
// num_row_tiles = ceil(numRows / SUBTILE_ROWS) sub-tiles, and blockIdx.x enumerates sub-tiles
// (all sub-tiles of tile 0 first, then tile 1, ...). Every thread owns a single column
// (lane + 32*tile) and therefore needs only one column statistic. The SUBTILE_ROWS row statistics
// of the sub-tile are staged in shared memory.
//
// Per load, the block reads THREADS*ITEMS_PER_THREAD ints. A warp-striped exchange then gives
// each thread ITEMS_PER_THREAD consecutive rows of its column: warp w holds rows
// [w*ITEMS_PER_THREAD, (w+1)*ITEMS_PER_THREAD).
//
// NOTE(review): smem_rowStats is indexed up to row_offset + THREADS*ITEMS_PER_THREAD/32 - 1, so
// SUBTILE_ROWS appears to need to be a multiple of THREADS*ITEMS_PER_THREAD/32 — confirm for all
// instantiations.
//
// Unused parameters, kept for interface compatibility: newRowStats, newcolStats, tileCols, n.
// bias may be NULL.
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n)
{
  // Nothing to do. This also avoids dividing by num_row_tiles == 0 and taking `% numRows` below.
  // numRows is uniform, so the whole block returns before any barrier.
  if(numRows <= 0)
    return;

  const int n_out = numRows*numCols;

  // Number of sub-tiles needed to cover the numRows rows of one 32-column tile.
  int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
  // The column advances by 32 every num_row_tiles blocks.
  int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
  // The first row of this sub-tile grows by SUBTILE_ROWS per block and wraps every num_row_tiles blocks.
  int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);

  int local_values[ITEMS_PER_THREAD];
  half local_output[ITEMS_PER_THREAD];
  float local_rowStats[ITEMS_PER_THREAD];
  __shared__ float smem_rowStats[SUBTILE_ROWS];

  typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
  typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
  __shared__ typename LoadInt32::TempStorage loadint32;
  __shared__ typename ExchangeInt32::TempStorage exchangeint32;

  // Per-thread column statistic and bias. Out-of-range columns get 0 and are never stored.
  float colStat = col >= numCols ? 0.0f : colStats[col];
  float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);

  // Stage the sub-tile's row statistics in shared memory. Rows past numRows wrap around; they
  // only feed outputs that the store loop discards.
  for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
  {
    int row = (base_row+j) % numRows;
    smem_rowStats[j] = rowStats[row];
  }
  __syncthreads();

  const int items_per_load = THREADS*ITEMS_PER_THREAD;
  const int rows_per_load = items_per_load/32;

  // First row (within the current load) held by this thread's warp.
  int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD;
  int row_offset = 0;
  // Offset of this sub-tile inside A: full numRows*32 tiles before it, plus base_row rows of 32.
  int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32);
  for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load)
  {
    // All values here are block-uniform, so every thread breaks together and the barrier at the
    // end of the iteration is reached by the whole block.
    int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset);
    int valid_items = valid_rows*32;
    if(valid_items <= 0) // the sub-tile may extend past the last row of the tile
      break;

    // Load blocked, then exchange so that lane l of warp w holds column l and rows
    // w*ITEMS_PER_THREAD + [0, ITEMS_PER_THREAD).
    LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0);
    ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values);

    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
      local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];

    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
      local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue);

    // Row-major store, one element at a time. outIdx < n_out also rejects rows >= numRows.
    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
    {
      int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols);
      if(outIdx< n_out && col < numCols)
        out[outIdx] = local_output[j];
    }

    row_offset += rows_per_load;

    // exchangeint32 is reused by the next iteration. CUB requires a barrier before its temporary
    // storage is reused, otherwise the next exchange can overwrite data another thread has not read yet.
    __syncthreads();
  }
}


// Quantizes a row-major fp16 matrix A (rows x cols) to int8 twice, once per normalization axis:
//   out_row_normed[r, c] = rint(A[r, c] * 127 / rowStats[r])
//   out_col_normed[r, c] = rint(A[r, c] * 127 / colStats[c])
//
// Launch layout: one block per TILE_ROWS x TILE_COLS tile. blockIdx.x walks the tiles of one row
// band (tiledCols / TILE_COLS of them) before moving TILE_ROWS rows down. Within a row, thread t
// handles ITEMS_PER_THREAD consecutive columns starting at base_col + t*ITEMS_PER_THREAD.
// Assumes TILE_COLS == THREADS*ITEMS_PER_THREAD and blockDim.x == THREADS.
//
// SPARSE_DECOMP != 0: elements with |A| >= threshold are appended as COO triplets
// (rowidx, colidx, val). Write positions come from per-row counters seeded from
// nnz_block_ptr[TILE_ROWS*blockIdx.x + i] and advanced with atomicInc. Such elements get 0 in
// out_row_normed.
// NOTE(review): out_col_normed still receives the quantized outlier value — confirm this
// asymmetry is intended.
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
{
  // The row band advances after tiledCols/TILE_COLS blocks. The column offset wraps to 0 there.
  const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
  const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
  const int base_idx = (base_row*cols) + base_col;
  const int items_per_load = ITEMS_PER_THREAD*THREADS;
  // Columns of this tile that lie inside the matrix. This is the same for every row of the tile.
  const int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;

  typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
  __shared__ typename LoadHalf::TempStorage loadhalf;
  typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
  __shared__ typename StoreInt8::TempStorage storeint8;

  __shared__ float smem_row_stats[TILE_ROWS];
  __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];

  half local_data[ITEMS_PER_THREAD];
  float local_col_stats[ITEMS_PER_THREAD];
  char local_quantized_data[ITEMS_PER_THREAD];

  // 0. Column scales (127/absmax) stay in registers: a thread keeps the same columns for every row.
  #pragma unroll ITEMS_PER_THREAD
  for(int j = 0; j < ITEMS_PER_THREAD; j++)
  {
    const int col = base_col + (threadIdx.x*ITEMS_PER_THREAD) + j;
    // Out-of-range columns get 0 so the register is never read uninitialized. Their results are
    // discarded by the valid_items-bounded stores.
    local_col_stats[j] = col < cols ? __fdividef(127.0f, colStats[col]) : 0.0f;
  }

  // Row absmax values (and, with SPARSE_DECOMP, the per-row COO write cursors) go to shared memory.
  for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
  {
    if(base_row + i < rows)
      smem_row_stats[i] = rowStats[base_row+i];

    if(SPARSE_DECOMP)
      smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
  }
  __syncthreads();

  // 1. Process the tile one row at a time.
  for(int row = 0; row < TILE_ROWS; row++)
  {
    // Block-uniform, so the barrier further down is reached by all threads or by none.
    if(base_row + row >= rows){ break; }
    const int i = base_idx + (row*cols);

    LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
    const float row_stat = __fdividef(127.0f, smem_row_stats[row]);

    // 2a. Row-normalized quantization: value/absmax*127, rounded to nearest.
    // NOTE(review): out-of-range columns hold 0 (the Load default). With threshold <= 0 they would
    // also be recorded as outliers; presumably SPARSE_DECOMP is only used with threshold > 0.
    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
    {
      const float x = __half2float(local_data[j]);
      if(SPARSE_DECOMP && fabsf(x) >= threshold)
      {
        // Outlier: zero it in the int8 output and record it as a COO entry.
        local_quantized_data[j] = 0;

        int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);

        rowidx[old_idx] = base_row+row;
        colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
        val[old_idx] = local_data[j];
      }
      else
      {
        local_quantized_data[j] = (char)(rintf(x*row_stat));
      }
    }

    StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);

    // 2b. Column-normalized quantization of the same row.
    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
      local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));

    __syncthreads();
    StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);
  }
}

// Converts a row-major int8 matrix A (rows x cols) into one of the tiled int8 layouts selected by
// FORMAT (COL32, COL_TURING, COL_AMPERE), optionally transposing it (TRANSPOSE != 0).
// outRows/outCols bound the output matrix.
//
// Launch layout: one block per TILE_ROWS x TILE_COLS tile of A. blockIdx.x walks the tiles of a
// row band (tiledCols / TILE_COLS of them) before moving TILE_ROWS rows down. Each warp loads one
// row at a time, up to 32*ITEMS_PER_THREAD bytes (lane + 32*j), into a static shared buffer of
// 32*33*ITEMS_PER_THREAD bytes. It then writes the buffered rows out in the target layout.
//
// Output layouts:
//   COL32:   numRows x 32 tiles, one after another (all rows of 32 columns, then the next 32 columns).
//   TURING:  8x32 tiles made of 4x4 sub-tiles. The first half of a tile holds the 4x4 sub-tiles of
//            the even rows, the second half those of the odd rows. Tiles advance over rows first
//            (next 8 rows = A[8:16, 0:32]), then the column index increases by 32.
//   AMPERE:  32x32 tiles made of 8x32 sub-tiles, with rows interleaved in pairs offset by 8
//            (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27] ...
//            Tiles advance over rows first, then columns.
//
// NOTE(review): assumptions implied by the indexing, to confirm for all instantiations:
//   - TILE_ROWS <= 32. The shared buffer holds 32 rows; transposed writes use a 33-element stride
//     indexed by `row`.
//   - TILE_COLS == 32*ITEMS_PER_THREAD, the number of columns a warp loads per row.
//   - THREADS/32 == ITEMS_PER_THREAD. The transpose paths advance ITEMS_PER_THREAD rows per j
//     where one row per warp is meant.
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
{
  // The row band advances after tiledCols/TILE_COLS blocks. The column offset wraps to 0 there.
  const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
  const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
  const int base_idx = (base_row*cols) + base_col;

  // 32 rows x 32*ITEMS_PER_THREAD columns. The transpose path stores columns with a 33-byte stride
  // (one byte of padding) to avoid bank conflicts.
  __shared__ char smem_data[32*33*ITEMS_PER_THREAD];
  char local_data[ITEMS_PER_THREAD];

  int warps = blockDim.x/32;
  int warp_id = threadIdx.x/32;
  int warp_lane = threadIdx.x % 32;
  int offset = 0;

  int smem_row = 0;
  // Each warp loads one row of up to 32*ITEMS_PER_THREAD bytes per iteration.
  for(int row = warp_id; row < TILE_ROWS; row+=warps)
  {
    int i = base_idx + (row*cols);
    int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;

    // 0. Load the row into registers, zero-filling anything outside the matrix.
    if(base_row + row < rows)
    {
      #pragma unroll ITEMS_PER_THREAD
      for(int j = 0; j < ITEMS_PER_THREAD; j++)
      {
        int col_idx = warp_lane+(j*32);
        if(col_idx < valid_items)
          local_data[j] = A[i+col_idx];
        else
          local_data[j] = 0;
      }
    }
    else
    {
      #pragma unroll ITEMS_PER_THREAD
      for(int j = 0; j < ITEMS_PER_THREAD; j++)
        local_data[j] = 0;
    }

    if(TRANSPOSE)
    {
      // Store transposed as (32*ITEMS_PER_THREAD) x 32 with a row stride of 33.
      #pragma unroll ITEMS_PER_THREAD
      for(int j = 0; j < ITEMS_PER_THREAD; j++)
      {
        int local_col = (32*j)+warp_lane;
        smem_data[(local_col*33) + row] = local_data[j];
      }
    }
    else
    {
      // Store as 32 rows x (32*ITEMS_PER_THREAD) columns.
      #pragma unroll ITEMS_PER_THREAD
      for(int j = 0; j < ITEMS_PER_THREAD; j++)
        smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
    }

    smem_row += warps;

    // 1. Once 32 rows are buffered, reorder them from shared memory into the target layout.
    if(smem_row % 32 == 0)
    {
      smem_row = 0;
      __syncthreads();

      for(int subrow = warp_id; subrow < 32; subrow+=warps)
      {
        for(int j = 0; j < ITEMS_PER_THREAD; j++)
        {
          switch(FORMAT)
          {
            case COL32:
              if(TRANSPOSE)
              {
                // Shared memory holds rows of 32 columns (stride 33). Each warp reads one row.
                // jrow advances ITEMS_PER_THREAD rows per j, and each subrow pass advances
                // ITEMS_PER_THREAD*ITEMS_PER_THREAD rows.
                const int jrow = j*ITEMS_PER_THREAD;
                const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD;
                if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
                {
                  char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];

                  // Every 32 output columns start a new outRows*32 tile. base_row steps by 32.
                  offset = base_row*outRows;
                  // threadIdx.x == warp_id*32 + warp_lane, i.e. output row (… + warp_id), column warp_lane.
                  out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
                }
              }
              else
              {
                if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
                {
                  offset = (base_col/32)*(32*rows);
                  char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
                  out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
                }
              }
              break;
            case COL_TURING:
              // 8x32 tiles of 4x4 sub-tiles: even rows fill the first 128 bytes, odd rows the next
              // 128. Within a half, every 4 columns add 16, and every pair of rows adds 4.
              // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
              if(TRANSPOSE)
              {
                const int jrow = j*ITEMS_PER_THREAD;
                const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD;
                if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
                {
                  char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];

                  // One 256-element tile per 8 output rows.
                  int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256;

                  // Column-tile offset: 256*(outRows/8)*(base_row/32) simplifies to outRows*base_row,
                  // since outRows is a multiple of 8 and base_row a multiple of 32.
                  int col_offset = outRows*base_row;

                  offset = row_offset+col_offset;

                  // jrow and subrow_loop_row are multiples of 8, so warp_id decides the parity
                  // and the position within the 8-row tile.
                  if(warp_id % 2 == 1)
                    // odd
                    offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
                  else
                    // even
                    offset += 0   + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);

                  out[offset] = data;
                }
              }
              else
              {
                if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
                {
                  char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
                  // Global 8x32 tile offset: outRows*32 per 32 columns, 256 per 8 rows.
                  offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256);
                  // Local 4x4 sub-tile offset: odd rows start at 128, every 4 columns add 16,
                  // col % 4 selects the element, and each even row step adds 2*(row % 8).
                  int subcol = warp_lane;

                  if(subrow % 2 == 1)
                    // odd
                    offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
                  else
                    // even
                    offset += 0   + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);

                  out[offset] = data;
                }
              }
              break;
            case COL_AMPERE:
              // 32x32 tiles with rows interleaved in pairs offset by 8:
              // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27] ...
              if(TRANSPOSE)
              {
                const int jrow = j*ITEMS_PER_THREAD;
                const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD;
                if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
                {
                  char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];

                  // One 1024-element tile per 32 output rows.
                  int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024;

                  // Column-tile offset: 1024*(outRows/32)*(base_row/32) simplifies to outRows*base_row.
                  int col_offset = outRows*base_row;

                  offset = row_offset+col_offset;

                  // Same interleaving as the non-transposed case below, with the output row index
                  // taken modulo 32. Rows beyond 32 are already accounted for in row_offset.
                  int local_row = (jrow + warp_id) % 32;
                  int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);

                  out[offset + (ampere_row*32) + warp_lane] = data;
                }
              }
              else
              {
                if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
                {
                  char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];

                  // Global 32x32 tile offset: outRows*32 per 32 columns, 1024 per 32 rows.
                  offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024);

                  // Row position inside the tile: pairs of rows are spread 8 apart.
                  int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);

                  out[offset + (local_row*32) + warp_lane] = data;
                }
              }
              break;
          }
        }
      }
    }
  }
}

// 1/127, presumably the inverse of the int8 quantization scale. Parenthesized so the macro
// expands as a single operand: without parentheses, `x / DENORM` would expand to x / 1.0f / 127.0f.
#define DENORM (1.0f/127.0f)

// Capacity of the per-thread register arrays for one sparse row in kspmm_coo_very_sparse_naive
// (local_valA / local_colidxA are sized by it).
#define MAX_SPARSE_COUNT 32
// Parenthesized for the same reason as DENORM, e.g. `i % SMEM_SIZE` or `x / SMEM_SIZE`.
#define SMEM_SIZE (8*256)
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
  // Sparse (COO) x dense product with in-place accumulation into `out`:
  //   out[row, c] += sum_i values[offset+i] * B[colidx[offset+i], c]
  // (additionally multiplied by dequant_stats[c]*DENORM when BITS == 8 and dequant_stats != NULL).
  // One block handles one sparse row: blockIdx.x selects count = max_count[blockIdx.x] non-zeros
  // starting at offset (derived from max_idx/offset_rowidx); the output row is rowidx[offset].
  // nnz, rowsA and rowsB are unused.
  //
  // NOTE(review): assumptions to confirm with the caller:
  //   * count <= MAX_SPARSE_COUNT (A values/columns are held in MAX_SPARSE_COUNT registers),
  //   * blockDim.x*SPMM_ITEMS <= SMEM_SIZE (size of the shared dequantization buffer),
  //   * colsB is a multiple of 8 (B and out are accessed as vectors of 8 elements).
  //
  // 0. load balancing: We process rows with most columns first (count_vec) and we process one row per block
  //    If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
  //    elements they finish faster "filling up" the gaps left by larger blocks
  //
  // without tensor cores
  // 1. use rowidx_length to find what to load (as many blocks as there are rows)
  // 2. Load A into registers
  // 3. each warp loads all required rows of B but each warp is offset by k
  // 4. Do mma operations that accumulate into registers
  // 5. Each warp stores its output row into matrix C

  const int count = max_count[blockIdx.x];
  const int local_max_idx = max_idx[blockIdx.x];
  const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
  const int local_row_idx = rowidx[offset];

  const int warp_id = threadIdx.x / 32;
  const int warp_idx = threadIdx.x % 32;
  const int warp_offset = (warp_id*32)*SPMM_ITEMS;
  // 8 elements per vector access: float2 of int8 (BITS == 8) or float4 of half
  constexpr int num_items = 8;
  // first column of this warp's slice in the current window (differs per warp)
  int idx_col_B = warp_offset;
  // first column of the current window of blockDim.x*SPMM_ITEMS columns (block-uniform)
  int local_idx_col_B_offset = 0;

  half local_valA[MAX_SPARSE_COUNT];
  int local_colidxA[MAX_SPARSE_COUNT];
  half local_valC[SPMM_ITEMS];
  T local_valsB[num_items];
  half local_valOut[num_items];

  // 2. Load A into registers
  for(int j = 0; j < MAX_SPARSE_COUNT; j++)
  {
    local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
    local_colidxA[j] = j < count ? colidx[offset+j] : 0;
  }

  // Each thread processes SPMM_ITEMS consecutive columns per iteration, so one iteration of the
  // block covers blockDim.x*SPMM_ITEMS columns. smem_dequant_stats[i] holds the statistic of
  // column local_idx_col_B_offset + i.
  __shared__ half smem_dequant_stats[SMEM_SIZE];

  // The loop condition is block-uniform so that every thread reaches the __syncthreads() calls.
  // Warps whose slice starts at or beyond colsB do no work thanks to the column bounds checks below.
  while(local_idx_col_B_offset < colsB)
  {
    if(dequant_stats != NULL)
    {
      for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
        if((local_idx_col_B_offset+i) < colsB)
          smem_dequant_stats[i] = dequant_stats[local_idx_col_B_offset+i];

      __syncthreads();
    }

    #pragma unroll SPMM_ITEMS
    for(int j = 0; j < SPMM_ITEMS; j++)
      local_valC[j] = 0.0f;

    #pragma unroll
    for(int i = 0; i < count; i++)
    {
        // 3. each warp loads all required rows of B but each warp is offset by k
        int row_offset = colsB*local_colidxA[i];

        #pragma unroll SPMM_ITEMS
        for(int j = 0; j < SPMM_ITEMS; j+=num_items)
        {
          // 4. Multiply the tile -> accumulate outputs in registers
          int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
          if(idx >= colsB){ break; }
          if((idx+num_items < colsB))
          {
            // one vector load of num_items elements (8 bytes for int8, 16 bytes for half)
            if(BITS == 8)
              reinterpret_cast<float2*>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
            else
              reinterpret_cast<float4*>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
          }
          else
          {
            #pragma unroll num_items
            for(int k = 0; k < num_items; k++)
              if(idx+k < colsB)
                local_valsB[k] = B[row_offset+idx+k];
              else
                local_valsB[k] = 0.0f;
          }
          #pragma unroll num_items
          for(int k = 0; k < num_items; k++)
          {
            if(BITS == 8 && dequant_stats != NULL)
            {
              // int8 B: rescale with the per-column statistic staged in shared memory
              float valB = local_valsB[k];
              float valA = local_valA[i];
              if(valB != 0.0f && valA != 0.0f)
                local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA;
            }
            else
              local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
          }
        }
    }

    int idx_row_C = (colsB*local_row_idx);

    #pragma unroll SPMM_ITEMS
    for(int j = 0; j < SPMM_ITEMS; j+=num_items)
    {
      int idx_col_C =  idx_col_B + warp_idx*SPMM_ITEMS + j;
      int idx_val = idx_col_C + idx_row_C;

      if(idx_col_C +num_items < colsB)
      {
          // load outputs to do inplace addition
          reinterpret_cast<float4*>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items];

          // local_valC[j .. j+num_items) holds the partial sums of exactly these num_items columns
          #pragma unroll num_items
          for(int k = 0; k < num_items; k++)
            local_valC[j + k] = (float)local_valC[j + k] + (float)local_valOut[k];

          reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4*>(local_valC)[j/num_items];
      }
      else
      {
        #pragma unroll num_items
        for(int k = 0; k < num_items; k++)
         if(idx_col_C + k < colsB)
           out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
      }
    }

    // every warp must be done reading smem_dequant_stats before the next iteration refills it
    if(dequant_stats != NULL)
      __syncthreads();

    idx_col_B += blockDim.x*SPMM_ITEMS;
    local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
  }
}

template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
{
	// Gathers one column of the tiled int8 matrix A (COL_TURING or COL_AMPERE layout) into out.
	// blockIdx.x selects the column local_colidx = idx[blockIdx.x]; row `row` of that column is
	// written to out[row*idx_size + blockIdx.x]. Threads stride over the rows with step blockDim.x.
	// colsA, tiledRowsA and tiledColsA are unused.
	int local_colidx = idx[blockIdx.x];

	if(FORMAT==COL_TURING)
	{
		// TURING FORMAT:
		// 8*32 tiles with 4*4 subtiles
		// the 8*32 tile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
		// the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty the values are zero
		// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
		// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
		// index increases by 32
		// columns are grouped in increments of 4, meaning that one has the following rows and columns
		// rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
		// cols: [0 1 2 3, 0 1 2 3, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]

		// these depend only on the column, so they are the same for every row of this block
		const int offset_per_col_tile = ((rowsA+7)/8)*32*8;
		const int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
		const int subtile_col_idx = local_colidx%32;

		// each thread reads 1 element = 1 row
		for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
		{
			const int tile_offset_rows = (row/8)*32*8;
			const int subtile_row_idx = row % 8;
			int offset = 0;
			if(row % 2 == 1)
				// odd rows live in elements 128..255 of the 8*32 tile
				offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
			else
				// even rows live in elements 0..127 of the 8*32 tile
				offset += 0   + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);

			offset += tile_offset_rows + tile_offset_cols;

			char val = A[offset];
			int out_idx = (row*idx_size) + blockIdx.x;
			out[out_idx] = val;
		}
	}
	else if(FORMAT == COL_AMPERE)
	{
		// the matrix is stored as 32x32 tiles (column-tile major); the position inside a tile
		// follows the COL32 row permutation from the cublasLt documentation
		const int offset_per_col_tile = ((rowsA+31)/32)*32*32;
		const int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
		const int subtile_col_idx = local_colidx%32;

		for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
		{
			const int tile_offset_rows = (row/32)*32*32;
			const int subtile_row_idx = row % 32;
			// this magic is taken from the cublasLt doc (search for COL32)
			int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
			offset += tile_offset_cols + tile_offset_rows;

			char val = A[offset];
			int out_idx = (row*idx_size) + blockIdx.x;
			out[out_idx] = val;
		}
	}
}

//template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
//{
//// element-wise kernel
//// 1. Load batch x k into registers
//// 2. Load k x k into registers
//// 3. dequantize and store in second pair of k x k
//// 4. matmul
//// 5. sum with cub
//// 6. store outputs
//// TC kernel
//// use k warps per thread block
//// 1. threadblock use read-only cache to read in register tile for A into shared memory
//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
//// 3. each warp reads a segment of values 16x32 from B 
//// 4. do dequantization from register of B into second pair of registers
//// 5. store (4) into fragment
//// 6. matmul aggregate into fragment C
//// 7. aggregate tiles of C into shared memory block C
//// 8. sum (7)
//// 9. write outputs to matmul output matrix
//}
Tim Dettmers's avatar
Tim Dettmers committed
3027

Tim Dettmers's avatar
Tim Dettmers committed
3028
// Loads ITEMS consecutive elements buffer[idx .. idx+ITEMS) into local[0 .. ITEMS).
//
// Bounds are checked with (limit_base, limit) rather than idx, so callers can pass an idx that
// includes a row offset while checking against the row length. Element k is valid when
// limit_base + k < limit; invalid elements are set to (T)zero_value.
//
// Fast path: when all ITEMS elements are valid and idx is a multiple of ITEMS, one TCAST-wide
// load of buffer[idx/ITEMS] is issued (assumes buffer itself is aligned for TCAST). Otherwise
// the elements are loaded one by one, which avoids reading the wrong elements when idx is not
// a multiple of ITEMS (idx/ITEMS would round down).
template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f)
{
    static_assert(sizeof(TCAST) == ITEMS*sizeof(T), "vector_load: TCAST must span exactly ITEMS elements of T");

    if(limit_base + ITEMS <= limit && idx % ITEMS == 0)
      reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
    else
    {
      #pragma unroll ITEMS
      for(int k = 0; k < ITEMS; k++)
      {
        if(limit_base + k < limit)
          local[k] = buffer[idx+k];
        else
          local[k] = (T)zero_value;
      }
    }
}

#define WARPS 5

// Experimental tensor-core (WMMA) kernel that multiplies the first row of A with B:
//
//   out[col_offset + c] = sum_{k < K} A[k] * B[(col_offset + c)*ldb + k],   c in [0, 32), col_offset + c < M
//
// with col_offset = blockIdx.x*32 (one block per 32 output columns). N, lda, ldc and the
// BITS/THREADS template parameters are currently unused.
//
// Launch layout (from the indexing below):
//   * warps 0..WARPS-2 copy 16-element K-chunks of A and of 32 rows of B into a double-buffered
//     shared-memory stage; warp WARPS-1 consumes the stage with 8x32x16 WMMA fragments; any
//     further warps only take part in the barriers.
//     NOTE(review): with blockDim.x < WARPS*32 no warp consumes the tiles and `out` is not written.
//   * T must be half (the WMMA fragments are half); WMMA requires SM70+.
//   * The 8x16 A fragment also picks up the 7 following chunks as rows 1..7; those rows of C
//     are computed but only row 0 is written to out.
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A,  T* B,  T * out,  int lda, int ldb, int ldc)
{
  // 16-element K-chunks ("tiles") produced per stage: two half-warps per loader warp
  constexpr int batch_size_warps = (WARPS-1)*2;
  // K elements consumed per stage
  constexpr int k_per_stage = batch_size_warps*16;
  // stride between tiles: A tiles hold 16 values, B tiles 32 columns x 16 values (+16 padding)
  constexpr int a_tile_offset = 16;
  constexpr int b_tile_offset = (16*32 + 16);
  // double buffering: loaders fill one stage while the consumer warp reads the other
  constexpr int num_stages = 2;
  constexpr int consumer_warp = WARPS-1;

  const int col_offset = blockIdx.x*32;
  const int warp_id = threadIdx.x / 32;
  const int half_warp_id = threadIdx.x / 16;
  const int half_warp_lane = threadIdx.x % 16;
  const bool is_loader = warp_id < consumer_warp;

  // The consumer reads a full 8x16 A fragment / 16x32 B fragment starting at the last tile
  // (num_stages*batch_size_warps - 1), so both buffers extend one fragment past that tile.
  __shared__ T smem_A[(num_stages*batch_size_warps - 1)*a_tile_offset + 8*16];
  __shared__ T smem_B[(num_stages*batch_size_warps - 1)*b_tile_offset + 16*32];

  wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
  wmma::fill_fragment(c_frag, 0.0f);

  // Loader threads: stage A[idx] and B[(col_offset+col)*ldb + idx] for the block's 32 columns,
  // idx = base_idx + threadIdx.x. Elements with idx >= K or column >= M are zero-filled.
  auto load_stage = [&](int base_idx, int stage)
  {
    const int idx = base_idx + threadIdx.x;
    const int tile = stage*batch_size_warps + half_warp_id;

    // load into registers first, then store to shared memory
    T local_A = T(0.0f);
    T local_B[32];
    if(idx < K)
      local_A = A[idx];

    #pragma unroll 32
    for(int col = 0; col < 32; col++)
      local_B[col] = (idx < K && (col_offset+col) < M) ? B[(col_offset+col)*ldb+idx] : T(0.0f);

    smem_A[half_warp_lane + tile*a_tile_offset] = local_A;

    #pragma unroll 32
    for(int col = 0; col < 32; col++)
      smem_B[half_warp_lane + tile*b_tile_offset + (col*16)] = local_B[col];
  };

  // Consumer warp: accumulate every tile of `stage` into c_frag.
  auto consume_stage = [&](int stage)
  {
    for(int k = 0; k < batch_size_warps; k++)
    {
      const int tile = stage*batch_size_warps + k;
      wmma::load_matrix_sync(a_frag, &(smem_A[tile*a_tile_offset]), 16);
      wmma::load_matrix_sync(b_frag, &(smem_B[tile*b_tile_offset]), 16);
      wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
  };

  int ticktock = 0;
  // prefetch the first stage
  if(is_loader)
    load_stage(0, ticktock);
  ticktock = 1 - ticktock;

  // The K step is fixed by the number of loader warps (not blockDim.x), so additional warps
  // neither skip K elements nor stall the loop.
  for(int base_idx = k_per_stage; base_idx < K; base_idx += k_per_stage)
  {
    // the previous stage is complete and the consumer is done with the stage overwritten below
    __syncthreads();
    if(is_loader)
      load_stage(base_idx, ticktock);
    ticktock = 1 - ticktock;

    if(warp_id == consumer_warp)
      consume_stage(ticktock);
  }

  __syncthreads();
  if(warp_id != consumer_warp){ return; }
  // only warp consumer_warp from here

  // the stage written last by the loaders
  ticktock = 1 - ticktock;
  consume_stage(ticktock);

  // stage C (8x32, row-major) through shared memory and write its first row
  wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
  // lanes read elements that were stored by other lanes
  __syncwarp();

  const int warp_lane = threadIdx.x % 32;
  if(col_offset + warp_lane < M)
    out[col_offset + warp_lane] = smem_A[warp_lane];
}

template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B,  float *absmax, T * out,  int lda, int ldb, int ldc, int blocksize)
{
  // Each block produces 8 consecutive outputs out[first_col + c], c in [0, 8):
  //   out[first_col + c] = sum_k A[k] * dequant(B[first_col + c, k])
  // B is packed two 4-bit values per byte (high nibble first) in rows of ldb elements, and the
  // dequantized value is currently the raw nibble times the block's absmax entry.
  // Each thread walks K in chunks of 32, then the per-thread partials are summed with cub.
  // N, lda and ldc are unused.

  // block-wide sum of the per-thread partial dot products
  typedef cub::BlockReduce<T, THREADS> BlockSum;
  __shared__ typename BlockSum::TempStorage sum_storage;
  __shared__ T block_result[8];

  const int first_col = blockIdx.x*8;

  T a_vals[32];
  unsigned char packed_b[16];
  T b_vals[32];
  T partial[8];

  if(threadIdx.x < 8)
    block_result[threadIdx.x] = T(0);
  __syncthreads();

  #pragma unroll 8
  for(int c = 0; c < 8; c++)
    partial[c] = T(0);

  for(int base = threadIdx.x*32; base < K; base += blockDim.x*32)
  {
    // 32 consecutive values of A as four 8-wide vector loads (zero-filled past K);
    // they are reused for every column of B below
    #pragma unroll 4
    for(int part = 0; part < 4; part++)
      vector_load<T, int4, 8>(&(a_vals[part*8]), A, base + part*8, base + part*8, K);

    for(int c = 0; c < 8 && (first_col + c) < M; c++)
    {
      const int row_B = (first_col + c)*ldb;

      // 16 packed bytes = 32 nibbles; padding bytes are 0b01110111 (nibble 0111 twice,
      // the NF4 code for 0.0)
      vector_load<unsigned char, int4, 16>(packed_b, B, (row_B+base+1)/2, (base+1)/2, (K+1)/2, 0b01110111);

      const half scale = __ldg(&(absmax[(base + row_B)/blocksize]));

      #pragma unroll 16
      for(int b = 0; b < 16; b++)
      {
        b_vals[2*b]     = (half)(packed_b[b] >> 4)*scale;
        b_vals[2*b + 1] = (half)(packed_b[b] & 0x0F)*scale;
      }

      #pragma unroll 32
      for(int e = 0; e < 32; e++)
        partial[c] += a_vals[e]*b_vals[e];
    }
  }

  #pragma unroll 8
  for(int c = 0; c < 8; c++)
  {
    partial[c] = BlockSum(sum_storage).Reduce(partial[c], cub::Sum());
    // sum_storage is reused by the next reduction
    __syncthreads();
  }

  // only the first warp takes part in writing the results
  if(threadIdx.x >= 32)
    return;

  // the block aggregate is valid in thread 0 only
  if(threadIdx.x == 0)
  {
    #pragma unroll 8
    for(int c = 0; c < 8; c++)
      block_result[c] = partial[c];
  }

  __syncwarp();

  if(threadIdx.x < 8 && first_col + threadIdx.x < M)
    out[first_col + threadIdx.x] = block_result[threadIdx.x];
}

3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
3298
3299
3300
3301
3302
3303
3304
3305
3306
3307
3308
3309
3310
3311
3312
3313
3314
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356
3357
3358
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
3377
3378
3379
3380
3381
3382
3383
3384
3385
//#define ROWS 2
//template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A,  T* B,  T * out,  int lda, int ldb, int ldc)
//{
//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
//// 1. Load dataB into register
//// 2. Dequantize B
//// 3. Fetch data from A and multiply
//
//  typedef cub::BlockLoad<T, THREADS , ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
//  //__shared__ typename LoadA::TempStorage loada;
//  typedef cub::BlockLoad<T, THREADS , ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
//  //__shared__ typename LoadB::TempStorage loadb;
//  typedef cub::BlockReduce<T, THREADS> BlockReduce;
//  // Allocate shared memory for BlockReduce
//  //__shared__ typename BlockReduce::TempStorage reduce;
//
//  __shared__ union {
//    typename BlockReduce::TempStorage reduce;
//    typename LoadB::TempStorage loadb;
//    typename LoadA::TempStorage loada;
//  } temp_storage;
//
//
//	T dataA[ITEMS];
//  T local_B[ITEMS];
//  T local_accC[ROWS];
//	int valid_items = 0;
//  const int col_offset = blockIdx.x * 8;
//
//	__shared__ T tileA[ROWS*THREADS*ITEMS];
//	__shared__ T accumulatorC[ROWS*8];
//
//  //#pragma unroll 8
//  //for(int i = 0; i < 8; i++)
//  //  tileA[threadIdx.x + (i*256)] = 0.0f;
//  //__syncthreads();
//  if(threadIdx.x < 64)
//    accumulatorC[threadIdx.x] = 0.0f;
//  __syncthreads();
//
//
//	for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS)
//	{
//		valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
//		int baserow = 0;
//		for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
//		{
//			LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
//
//      #pragma unroll ITEMS
//      for(int k = 0; k < ITEMS; k++)
//          tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k];
//
//		__syncthreads();
//		}
//		baserow += ROWS;
//
//    // load 16 columns from B at a time. B is transposed, so its like loading rows
//    // each warp loads one row
//    // each thread loads 128 byte
//
//    // col: inner_idx + warp_lane
//    // row: ldb*(offset + warp_id)
//    for(int col = 0; col < 8 && (col_offset + col) < M; col++)
//    {
//      int colB = col_offset + col;
//
//      for(int k = 0; k < ROWS; k++)
//        local_accC[k] = 0.0f;
//
//      int base_idxB = ldb*colB;
//      valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
//      LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
//      __syncthreads();
//
//      for(int row = 0; row < ROWS && row < N; row++)
//      {
//        #pragma unroll ITEMS
//        for(int k = 0; k < ITEMS; k++)
//        {
//          int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k);
//          local_accC[row] += tileA[idxA]*local_B[k];
//        }
//
//        local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
//        if(threadIdx.x == 0)
//          atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
//      }
//    }
//	}
//
//  for(int row = 0; row < ROWS && row < N; row++)
//  {
//    int out_idx = ldc*row + col_offset;
//
//    //if(threadIdx.x < 8)
//    //  if(accumulatorC[row*8 + threadIdx.x] != 0.0)
//    //    printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
//
//    if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
//    {
//      //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
//      out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
//    }
//  }
//
//
//
//}

Tim Dettmers's avatar
Tim Dettmers committed
3386

Tim Dettmers's avatar
Tim Dettmers committed
3387
3388
3389
3390
3391
3392
3393
3394
3395
3396
3397
3398
3399
3400
3401
3402
3403
3404
3405
3406
3407
3408
3409
3410
3411
3412
3413
3414
3415
3416
3417
3418
3419
3420
3421
3422
3423
3424
3425
3426
3427
3428
3429
3430
3431
// Placeholder consumer for with_staging_unified: receives the output slice and the
// shared-memory stage that was just filled, and currently does nothing with them.
__device__ void compute(float* global_out, float const* shared_in)
{
  // intentionally empty; the casts only mark the parameters as unused
  (void)global_out;
  (void)shared_in;
}
// Multi-stage copy pipeline: each block streams batch_sz chunks of block.size() floats from
// global_in into a ring of stages_count shared-memory buffers with cuda::memcpy_async, and
// hands each filled buffer to compute() together with the matching slice of global_out.
// Requires stages_count * block.size() * sizeof(float) bytes of dynamic shared memory and
// size == batch_sz * grid.size() (asserted).
template <size_t stages_count /* Pipeline with stages_count stages */>
__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz) {
    auto grid = cooperative_groups::this_grid();
    auto block = cooperative_groups::this_thread_block();
    assert(size == batch_sz * grid.size()); // Assume input size fits batch_sz * grid_size

    extern __shared__ float shared[]; // stages_count * block.size() * sizeof(float) bytes
    size_t shared_offset[stages_count];
    for (size_t s = 0; s < stages_count; ++s) shared_offset[s] = s * block.size();

    __shared__ cuda::pipeline_shared_state<
        cuda::thread_scope::thread_scope_block,
        stages_count
    > shared_state;
    auto pipeline = cuda::make_pipeline(block, &shared_state);

    // Element offset of this block's chunk of `batch`. Kept as size_t: offsets range up to
    // `size` and would be truncated by an int.
    auto block_batch = [&](size_t batch) -> size_t {
        return block.group_index().x * block.size() + grid.size() * batch;
    };

    // compute_batch: next batch to process
    // fetch_batch:  next batch to fetch from global memory
    for (size_t compute_batch = 0, fetch_batch = 0; compute_batch < batch_sz; ++compute_batch) {
        // The outer loop iterates over the computation of the batches
        for (; fetch_batch < batch_sz && fetch_batch < (compute_batch + stages_count); ++fetch_batch) {
            // This inner loop iterates over the memory transfers, making sure that the pipeline is always full
            pipeline.producer_acquire();
            size_t shared_idx = fetch_batch % stages_count;
            size_t batch_idx = fetch_batch;
            size_t block_batch_idx = block_batch(batch_idx);
            cuda::memcpy_async(block, shared + shared_offset[shared_idx], global_in + block_batch_idx, sizeof(float) * block.size(), pipeline);
            pipeline.producer_commit();
        }
        pipeline.consumer_wait();
        size_t shared_idx = compute_batch % stages_count;
        size_t batch_idx = compute_batch;
        compute(global_out + block_batch(batch_idx), shared + shared_offset[shared_idx]);
        pipeline.consumer_release();
    }
}

//==============================================================
//                   TEMPLATE DEFINITIONS
//==============================================================

Tim Dettmers's avatar
Tim Dettmers committed
3437
3438
3439
3440
3441
3442
3443
3444
3445
3446
3447
3448
3449
//template <class MShape, class NShape, class KShape,
//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
//          class Alpha, class Beta>
//__global__ static
//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
//void
//gemm_device(MShape M, NShape N, KShape K,
//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
//            TC      * out, CStride dC, CBlockLayout       , CThreadLayout tC,
//            half alpha, half beta);

// Explicit instantiations of gemm_device<T, BITS, THREADS>.
// these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A,  float* B,  float * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 160>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A,  float* B,  float * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 96>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
// these are not used and make no sense, but the compiler needs them

//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A,  float* B,  float * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
Tim Dettmers's avatar
Tim Dettmers committed
3466
template __global__ void gemm_device<half, 16, 160>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
3467
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
Tim Dettmers's avatar
Tim Dettmers committed
3468
3469
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A,  float* B,  float * out,  int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
Tim Dettmers's avatar
Tim Dettmers committed
3470
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
Tim Dettmers's avatar
Tim Dettmers committed
3471
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A,  half* B,  half * out,  int lda, int ldb, int ldc);
Tim Dettmers's avatar
Tim Dettmers committed
3472

// Explicit instantiation of kgemm_4bit_inference for half operands with a
// second template argument of 128. B is passed as packed unsigned char data
// together with a float absmax array and a runtime blocksize.
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B,  float *absmax, half * out,  int lda, int ldb, int ldc, int blocksize);

//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);

// Explicit instantiation of with_staging_unified with template argument 2.
template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz);

// Explicit instantiations of kExtractOutliers for the COL_TURING and
// COL_AMPERE layout tags.
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

// Explicit instantiations of kspmm_coo_very_sparse_naive. The element type of
// B is the first template argument; the half variants use 16 as the third
// template argument and the signed char variants use 8. Values and output are
// always half.
#define MAKE_kspmm_coo_very_sparse_naive(B_ELEM, ITEMS, BITS_ARG) \
template __global__ void kspmm_coo_very_sparse_naive<B_ELEM, ITEMS, BITS_ARG>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, B_ELEM *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

MAKE_kspmm_coo_very_sparse_naive(half,         8, 16)
MAKE_kspmm_coo_very_sparse_naive(half,        16, 16)
MAKE_kspmm_coo_very_sparse_naive(half,        32, 16)
MAKE_kspmm_coo_very_sparse_naive(signed char,  8,  8)
MAKE_kspmm_coo_very_sparse_naive(signed char, 16,  8)
MAKE_kspmm_coo_very_sparse_naive(signed char, 32,  8)

#undef MAKE_kspmm_coo_very_sparse_naive

// Explicit instantiations of kTransformRowToFormat<256, 8, 32, 32*8, FLAG, FORMAT>
// for each output format (COL32, COL_TURING, COL_AMPERE) with the fifth
// template argument set to 0 and to 1.
#define MAKE_kTransformRowToFormat(FLAG, FORMAT) \
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, FLAG, FORMAT>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);

MAKE_kTransformRowToFormat(0, COL32)
MAKE_kTransformRowToFormat(1, COL32)
MAKE_kTransformRowToFormat(0, COL_TURING)
MAKE_kTransformRowToFormat(1, COL_TURING)
MAKE_kTransformRowToFormat(0, COL_AMPERE)
MAKE_kTransformRowToFormat(1, COL_AMPERE)

#undef MAKE_kTransformRowToFormat

// Explicit instantiation of kdequant_mm_int32_fp16<4, 128, 512>: int32 input A,
// float row/column statistics, half output and half bias.
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);

// Explicit instantiations of kDoubleRowColQuant; the two variants differ only
// in the last template argument (0 or 1).
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);

// Device-side quantization helper, instantiated for template arguments 0 and 1.
template __device__ unsigned char dQuantize<0>(
    float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(
    float* smem_code, const float rand, float x);

// Quantile estimation kernel for float and half inputs. Note that max_val
// takes the element type of A, while offset is always float.
template __global__ void kEstimateQuantiles(
    float *__restrict__ const A, float *code,
    const float offset, const float max_val, const int n);
template __global__ void kEstimateQuantiles(
    half *__restrict__ const A, float *code,
    const float offset, const half max_val, const int n);

// Explicit instantiations of kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>
// for every (optimizer, gradient type) pair listed below.
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
                float* state1, float *unorm, \
                const float beta1, const float eps, const float weight_decay, \
                const int step, const float lr, const float gnorm_scale, const int n);

MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)

// Explicit instantiations of kOptimizer32bit1State<gtype, oname> for every
// (optimizer, gradient type) pair listed below.
// The macro body must contain only the declaration: any stray text on a
// backslash-continued line becomes part of every expansion.
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
    const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)

// Explicit instantiations of kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>
// for ADAM with float, half and __nv_bfloat16 gradients.
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p,  \
                float* state1, float* state2, float *unorm, \
                const float beta1, const float beta2, const float eps, const float weight_decay, \
                const int step, const float lr, const float gnorm_scale, const int n);

MAKE_PreconditionOptimizer32bit2State(ADAM, float)
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16)

// Explicit instantiations of kOptimizer32bit2State<T, ADAM> for float, half and
// __nv_bfloat16 gradients/parameters; both optimizer states are float.
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

// Explicit instantiations of kPreconditionOptimizerStatic8bit1State<GRAD_T, OPT>
// for every (optimizer, gradient type) pair listed below. The state is stored
// as unsigned char with a float quantile table.
#define MAKE_PreconditionStatic8bit1State(OPT, GRAD_T)                                  \
  template __global__ void kPreconditionOptimizerStatic8bit1State<GRAD_T, OPT>(         \
      GRAD_T* p, GRAD_T* __restrict__ const g,                                          \
      unsigned char* __restrict__ const state1, float* unorm,                           \
      const float beta1, const float eps, const int step,                               \
      float* __restrict__ const quantiles1, float* max1, float* new_max1,               \
      const float weight_decay, const float gnorm_scale, const int n);

MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
MAKE_PreconditionStatic8bit1State(RMSPROP, half)

// Explicit instantiations of kOptimizerStatic8bit1State<GRAD_T, OPT> for every
// (optimizer, gradient type) pair listed below.
#define MAKE_optimizerStatic8bit1State(OPT, GRAD_T)                                     \
  template __global__ void kOptimizerStatic8bit1State<GRAD_T, OPT>(                     \
      GRAD_T* p, GRAD_T* const g, unsigned char* state1,                                \
      const float* unorm, const float max_unorm, const float param_norm,                \
      const float beta1, const float eps, const int step, const float lr,               \
      float* __restrict__ const quantiles1, float* max1, float* new_max1,               \
      float weight_decay, const float gnorm_scale, const int n);

MAKE_optimizerStatic8bit1State(MOMENTUM, float)
MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(RMSPROP, float)
MAKE_optimizerStatic8bit1State(RMSPROP, half)

// Explicit instantiations of kPreconditionOptimizerStatic8bit2State<GRAD_T, OPT>
// for ADAM with float and half gradients. Note that this signature takes no
// weight_decay argument.
#define MAKE_PreconditionStatic8bit2State(OPT, GRAD_T)                                  \
  template __global__ void kPreconditionOptimizerStatic8bit2State<GRAD_T, OPT>(         \
      GRAD_T* p, GRAD_T* __restrict__ const g,                                          \
      unsigned char* __restrict__ const state1,                                         \
      unsigned char* __restrict__ const state2,                                         \
      float* unorm, const float beta1, const float beta2,                               \
      const float eps, const int step,                                                  \
      float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,       \
      float* max1, float* max2, float* new_max1, float* new_max2,                       \
      const float gnorm_scale, const int n);

MAKE_PreconditionStatic8bit2State(ADAM, float)
MAKE_PreconditionStatic8bit2State(ADAM, half)

// Explicit instantiations of kOptimizerStatic8bit2State<GRAD_T, OPT> for ADAM
// with float and half gradients.
#define MAKE_optimizerStatic8bit2State(OPT, GRAD_T)                                     \
  template __global__ void kOptimizerStatic8bit2State<GRAD_T, OPT>(                     \
      GRAD_T* p, GRAD_T* const g,                                                       \
      unsigned char* state1, unsigned char* state2,                                     \
      const float* unorm, const float max_unorm, const float param_norm,                \
      const float beta1, const float beta2,                                             \
      const float eps, const int step, const float lr,                                  \
      float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,       \
      float* max1, float* max2, float* new_max1, float* new_max2,                       \
      float weight_decay, const float gnorm_scale, const int n);

MAKE_optimizerStatic8bit2State(ADAM, float)
MAKE_optimizerStatic8bit2State(ADAM, half)

// Explicit instantiations of kPercentileClipping<T, 2048, 4> for float and half
// gradients; the gnorm_vec output is float in both cases.
template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);

// Explicit instantiations of
// kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>.
// Every (dtype, data type) pair gets the stochastic = 0 variant for block sizes
// 4096/2048/1024 (4 values per thread) and 512/256/128/64 (2 values per thread).
// General8bit additionally gets a stochastic = 1 variant at block size 4096.
#define MAKE_kQuantizeBlockwise(ELEM_T, BLOCK, PER_THREAD, STOCHASTIC, QUANT_T) \
template __global__ void kQuantizeBlockwise<ELEM_T, BLOCK, PER_THREAD, STOCHASTIC, QUANT_T>(float * code, ELEM_T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);

// Local helper: the stochastic = 0 variant for all seven block sizes.
#define MAKE_kQuantizeBlockwise_AllSizes(ELEM_T, QUANT_T) \
  MAKE_kQuantizeBlockwise(ELEM_T, 4096, 4, 0, QUANT_T)    \
  MAKE_kQuantizeBlockwise(ELEM_T, 2048, 4, 0, QUANT_T)    \
  MAKE_kQuantizeBlockwise(ELEM_T, 1024, 4, 0, QUANT_T)    \
  MAKE_kQuantizeBlockwise(ELEM_T,  512, 2, 0, QUANT_T)    \
  MAKE_kQuantizeBlockwise(ELEM_T,  256, 2, 0, QUANT_T)    \
  MAKE_kQuantizeBlockwise(ELEM_T,  128, 2, 0, QUANT_T)    \
  MAKE_kQuantizeBlockwise(ELEM_T,   64, 2, 0, QUANT_T)

// Stochastic variants (General8bit only).
MAKE_kQuantizeBlockwise(half,  4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)

// Deterministic variants for every element type / data type combination.
MAKE_kQuantizeBlockwise_AllSizes(half,  General8bit)
MAKE_kQuantizeBlockwise_AllSizes(float, General8bit)
MAKE_kQuantizeBlockwise_AllSizes(half,  FP4)
MAKE_kQuantizeBlockwise_AllSizes(float, FP4)
MAKE_kQuantizeBlockwise_AllSizes(half,  NF4)
MAKE_kQuantizeBlockwise_AllSizes(float, NF4)

#undef MAKE_kQuantizeBlockwise_AllSizes

// Explicit instantiations of kDequantizeBlockwise<T, 512, 64, 8, DataType> for
// half and float output and the FP4, General8bit and NF4 data types. The
// blocksize is a runtime argument here, not a template argument.
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);

// Explicit instantiations of
// kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>
// for ADAM with float, half and __nv_bfloat16 gradients (block_size 2048,
// num_per_thread 8). The macro body must contain only the declaration: stray
// text on a backslash-continued line becomes part of every expansion.
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
                const float beta1, const float beta2, \
                const float eps, const int step, const float lr, \
                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
                float* absmax1, float* absmax2,  \
                float weight_decay, \
                const float gnorm_scale, const bool skip_zeros, const int n);

MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)

// Explicit instantiations of
// kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>
// for MOMENTUM, RMSPROP and ADAGRAD with float and half gradients (block_size
// 2048, num_per_thread 8). The signature takes beta2 even though there is
// only one state buffer. The macro body must contain only the declaration:
// stray text on a backslash-continued line becomes part of every expansion.
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
                gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
                const float beta1, const float beta2, \
                const float eps, const int step, const float lr, \
                float* __restrict__ const quantiles1, \
                float* absmax1, \
                float weight_decay, \
                const float gnorm_scale, const bool skip_zeros, const int n);

MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)