// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include <ops.cuh>
#include <kernels.cuh>
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
#include <cassert>
#include <common.h>
#include <cstdio>
#include <iostream>

#define ERR_NOT_IMPLEMENTED 100

using namespace BinSearch;
using std::cout;
using std::endl;

// Launches kQuantize over n elements on the default stream, using 1024-thread
// blocks and ceil(n / 1024) blocks.
void quantize(float *code, float *A, unsigned char *out, int n)
{
  const int threads = 1024;
  const int blocks = n / threads + (n % threads == 0 ? 0 : 1);

  kQuantize<<<blocks, threads>>>(code, A, out, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kDequantize over n elements on `stream`, using 1024-thread blocks
// and ceil(n / 1024) blocks.
void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream)
{
  const int threads = 1024;
  const int blocks = n / threads + (n % threads == 0 ? 0 : 1);

  kDequantize<<<blocks, threads, 0, stream>>>(code, A, out, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kQuantizeBlockwise over n elements of A on the default stream, with
// one CUDA block per `blocksize` input elements (rounded up).
//
// Template parameters:
//   T          element type of A.
//   STOCHASTIC forwarded to the kernel only for blocksize == 4096; every other
//              block size is instantiated with STOCHASTIC = 0.
//   DATA_TYPE  quantization format selector, forwarded unchanged.
//
// Supported block sizes: 4096, 2048, 1024, 512, 256, 128, 64. Any other value
// is reported on stderr and nothing is launched.
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{
  // Validate before dividing by blocksize. A blocksize of 0 used to divide by
  // zero, and any other unsupported value silently launched nothing.
  switch(blocksize)
  {
    case 4096: case 2048: case 1024: case 512: case 256: case 128: case 64:
      break;
    default:
      fprintf(stderr, "quantizeBlockwise: unsupported blocksize %d\n", blocksize);
      return;
  }

  int num_blocks = n/blocksize;
  num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

  // In every branch the third template argument equals blocksize divided by
  // the threads-per-block count (presumably items per thread).
  if(blocksize == 4096)
    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
  else if(blocksize == 2048)
    kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
  else if(blocksize == 1024)
    kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
  else if(blocksize == 512)
    kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
  else if(blocksize == 256)
    kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
  else if(blocksize == 128)
    kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
  else // blocksize == 64 (guaranteed by the validation above)
    kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);

  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kDequantizeBlockwise on `stream` with 64 threads per CUDA block.
// Each CUDA block covers a tile of 1024 output elements when DATA_TYPE > 0 and
// 512 otherwise, so the grid is ceil(n / tile_size).
//
// The former unused `num_blocks` local computed n / blocksize, which divided
// by zero for blocksize == 0. It has been removed.
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream)
{
  const int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
  const int grid = (n + tile_size - 1) / tile_size;

  // For DATA_TYPE > 0 the kernel receives half the block size (presumably
  // because two 4-bit codes share one byte of A; confirm against the kernel).
  const int kernel_blocksize = (DATA_TYPE > 0) ? blocksize / 2 : blocksize;

  kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<grid, 64, 0, stream>>>(code, A, absmax, out, kernel_blocksize, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Host launcher for the 32-bit optimizer kernels on the default stream, with
// ceil(n / 4096) CUDA blocks per launch.
//
// When max_unorm > 0, *unorm is zeroed and a kPreconditionOptimizer32bit*State
// pass is launched with it (presumably to accumulate the update norm). For LION
// that pass runs after the update kernel; for all other optimizers it runs before.
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
  const int blocks = n / 4096 + (n % 4096 == 0 ? 0 : 1);
  const bool track_unorm = max_unorm > 0.0f;

  switch(OPTIMIZER)
  {
    case ADAM:
    case ADEMAMIX:
      if(track_unorm)
      {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
        kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
      }
      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;

    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
      if(track_unorm)
      {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
      }
      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;

    case LION:
      // LION applies the parameter update first; the momentum/norm pass follows.
      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());

      if(track_unorm)
      {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
      }
      break;

    default:
      break;
  }
}

// Host launcher for the static (non-blockwise) 8-bit optimizer kernels on the
// default stream, with ceil(n / 4096) CUDA blocks per launch.
//
// *unorm is zeroed up front when max_unorm > 0. new_max1 (and new_max2 for
// ADAM) are zeroed before the precondition kernel that receives them. For LION
// the precondition pass runs after the update kernel; otherwise it runs before.
// Optimizers without a case here launch nothing.
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
                unsigned char* state1, unsigned char* state2,
                float *unorm, float max_unorm, float param_norm,
                float beta1, float beta2,
                float eps, int step, float lr,
                float* quantiles1, float* quantiles2,
                float* max1, float* max2, float* new_max1, float* new_max2,
                float weight_decay,
                const float gnorm_scale, int n)
{
  const int blocks = n / 4096 + (n % 4096 == 0 ? 0 : 1);

  if(max_unorm > 0.0f)
    CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));

  switch(OPTIMIZER)
  {
    case ADAM:
    {
      CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
      CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));

      kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(
          p, g, state1, state2, unorm, beta1, beta2, eps, step,
          quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());

      kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(
          p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
          quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
    }

    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
    {
      CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));

      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(
          p, g, state1, unorm, beta1, beta2, eps, step,
          quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());

      kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(
          p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
          quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
    }

    case LION:
    {
      // LION: parameter update first, then the momentum/statistics pass.
      kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(
          p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
          quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());

      CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(
          p, g, state1, unorm, beta1, beta2, eps, step,
          quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
    }

    default:
      break;
  }
}

// Launch geometry for the blockwise 8-bit optimizer kernels: each CUDA block
// covers BLOCKSIZE_* elements with BLOCKSIZE_* / NUM_* threads.
#define BLOCKSIZE_2STATE 256
#define NUM_2STATE 1
#define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1

// Host launcher for the blockwise 8-bit optimizer kernels on the default
// stream. Two-state optimizers (ADAM, ADEMAMIX) use state1/state2 with their
// quantiles/absmax; one-state optimizers (MOMENTUM, RMSPROP, ADAGRAD, LION)
// use only state1. Any other OPTIMIZER launches nothing.
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
    T* p,
    T* g,
    unsigned char* state1,
    unsigned char* state2,
    float beta1,
    float beta2,
    float beta3,
    float alpha,
    float eps,
    int step,
    float lr,
    float* quantiles1,
    float* quantiles2,
    float* absmax1,
    float* absmax2,
    float weight_decay,
    const float gnorm_scale,
    bool skip_zeros,
    int n
) {
  switch(OPTIMIZER)
  {
    case ADAM:
    case ADEMAMIX:
    {
      const int blocks = n / BLOCKSIZE_2STATE + (n % BLOCKSIZE_2STATE == 0 ? 0 : 1);
      kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(
          p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
          quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale,
          skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
    }

    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
    case LION:
    {
      const int blocks = n / BLOCKSIZE_1STATE + (n % BLOCKSIZE_1STATE == 0 ? 0 : 1);
      kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(
          p, g, state1, beta1, beta2, eps, step, lr,
          quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
    }

    default:
      break;
  }
}


// Zeroes gnorm_vec[step % 100] and then launches kPercentileClipping over the
// n gradient elements on the default stream: 512 threads per CUDA block,
// ceil(n / 2048) blocks.
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
  float *slot = gnorm_vec + (step % 100);
  const int blocks = n / 2048 + (n % 2048 == 0 ? 0 : 1);

  CUDA_CHECK_RETURN(cudaMemset(slot, 0, 1*sizeof(float)));
  kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// int8 x int8 -> int32 GEMM through cublasGemmEx on context->m_handle with
// integer alpha = 1, beta = 0, CUDA_R_32I compute and the tensor-op default
// algorithm. A non-success cuBLAS status is printed to stdout; nothing is
// returned to the caller.
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{
  const int one = 1;
  const int zero = 0;
  const cublasOperation_t opA = transposeA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t opB = transposeB ? CUBLAS_OP_T : CUBLAS_OP_N;

  const cublasStatus_t status = cublasGemmEx(
      context->m_handle, opA, opB, m, n, k,
      &one,
      A, CUDA_R_8I, lda,
      B, CUDA_R_8I, ldb,
      &zero,
      C, CUDA_R_32I, ldc,
      CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);

  if (status != CUBLAS_STATUS_SUCCESS)
    std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}

// Batched int8 x int8 -> int32 GEMM through cublasGemmStridedBatchedEx on
// context->m_handle with integer alpha = 1, beta = 0, CUDA_R_32I compute and
// the default algorithm. strideA/strideB/strideC are the offsets between
// consecutive matrices of the batch, as defined by cuBLAS. A non-success
// status is printed to stdout; nothing is returned to the caller.
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
                    long long int strideA, long long int strideB, long long int strideC, int batchCount)
{
  const int one = 1;
  const int zero = 0;
  const cublasOperation_t opA = transposeA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t opB = transposeB ? CUBLAS_OP_T : CUBLAS_OP_N;

  const cublasStatus_t status = cublasGemmStridedBatchedEx(
      context->m_handle, opA, opB, m, n, k,
      &one,
      A, CUDA_R_8I, lda, strideA,
      B, CUDA_R_8I, ldb, strideB,
      &zero,
      C, CUDA_R_32I, ldc, strideC,
      batchCount,
      CUDA_R_32I, CUBLAS_GEMM_DEFAULT);

  if (status != CUBLAS_STATUS_SUCCESS)
    std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}

// Rounds v up to the next multiple of d (intended for v >= 0, d > 0).
int roundoff(int v, int d) {
  const int chunks = (v + d - 1) / d;
  return chunks * d;
}

// Maps a project layout tag (ROW, COL, COL32, COL_TURING, COL_AMPERE) to the
// matching cublasLtOrder_t. Unknown tags fall back to CUBLASLT_ORDER_ROW.
template<int ORDER> cublasLtOrder_t get_order()
{
  cublasLtOrder_t order = CUBLASLT_ORDER_ROW;
  switch(ORDER)
  {
    case COL:        order = CUBLASLT_ORDER_COL;            break;
    case COL32:      order = CUBLASLT_ORDER_COL32;          break;
    case COL_TURING: order = CUBLASLT_ORDER_COL4_4R2_8C;    break;
    case COL_AMPERE: order = CUBLASLT_ORDER_COL32_2R_4R4;   break;
    case ROW:
    default:         break;
  }
  return order;
}

template cublasLtOrder_t get_order<ROW>();
template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();


// Leading dimension for a dim1 x dim2 matrix stored in layout ORDER:
//   ROW        -> dim2
//   COL        -> dim1
//   COL32      -> 32 * dim1
//   COL_TURING -> 32 * dim1 rounded up to a multiple of 8
//   COL_AMPERE -> 32 * dim1 rounded up to a multiple of 32
// Unknown tags yield 0.
template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
  int ld = 0;
  switch(ORDER)
  {
    case ROW:        ld = dim2;                      break;
    case COL:        ld = dim1;                      break;
    case COL32:      ld = dim1 * 32;                 break;
    case COL_TURING: ld = 32 * roundoff(dim1, 8);    break;
    case COL_AMPERE: ld = 32 * roundoff(dim1, 32);   break;
    default:         break;
  }
  return ld;
}

// int8 matmul through cuBLASLt on `stream`, computing C = A^T @ B in
// column-major layout.
//
// Template parameters:
//   DTYPE_OUT   32 -> int32 output with int32 alpha/beta;
//               otherwise int8 output with float alpha/beta.
//   SCALE_ROWS  only used when DTYPE_OUT != 32: when set, row_scale is passed
//               as alpha under CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO,
//               so it must be a device vector with one entry per row of the output.
//
// Returns non-zero if any cuBLASLt call reported an error (via checkCublasStatus).
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
  cublasLtHandle_t ltHandle,
  int m, int n, int k,
  const int8_t * A,
  const int8_t * B,
  void * C,
  float * row_scale,
  int lda, int ldb, int ldc,
  cudaStream_t stream
) {
  // Calculate C = A^T @ B, in col-major layout.
  //
  // Use the IMMA kernels requires:
  // * A must be transposed and B must be non-transposed.
  // * Dimensions m and k must be multiples of 4.
  // * All pointers must be 4-byte aligned; 16-byte alignment preferred.

  int has_error = 0;

  cublasLtMatmulDesc_t matmulDesc;
  cublasLtMatrixLayout_t aDesc, bDesc, cDesc;
  cublasOperation_t opT = CUBLAS_OP_T;

  cudaDataType_t outType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_8I;
  cudaDataType_t scaleType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_32F;

  has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_8I, m, k, lda));
  has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_8I, m, n, ldb));
  has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc));

  // Default layout order is col major

  has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType));
  has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));

  if (DTYPE_OUT == 32) {
    int alpha = 1, beta = 0;
    has_error |= checkCublasStatus(cublasLtMatmul(
      ltHandle, matmulDesc,
      &alpha, A, aDesc,
      B, bDesc, &beta,
      (int32_t*)C, cDesc,
      (int32_t*)C, cDesc,
      NULL, NULL, 0, stream
    ));
  } else {
    // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows.

    if (!SCALE_ROWS) {
      float alpha = 1.0f, beta = 0.0f;
      has_error |= checkCublasStatus(cublasLtMatmul(
        ltHandle, matmulDesc,
        &alpha, A, aDesc,
        B, bDesc, &beta,
        (int8_t*)C, cDesc,
        (int8_t*)C, cDesc,
        NULL, NULL, 0, stream
      ));
    } else {
      // alpha is the per-row device vector row_scale; beta is zero. The
      // previous code also declared an unused ALPHA_DEVICE_VECTOR_BETA_HOST
      // local and took sizeof() of it, which obscured which mode is set.
      cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
      float beta = 0.0f;
      has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(
        matmulDesc,
        CUBLASLT_MATMUL_DESC_POINTER_MODE,
        &pointerMode,
        sizeof(pointerMode)
      ));
      has_error |= checkCublasStatus(cublasLtMatmul(
        ltHandle, matmulDesc,
        row_scale, A, aDesc,
        B, bDesc, &beta,
        (int8_t*)C, cDesc,
        (int8_t*)C, cDesc,
        NULL, NULL, 0, stream
      ));
    }
  }

  has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(cDesc));
  has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(bDesc));
  has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc));
  has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));

  // has_error is an OR of status flags; test for any non-zero value rather
  // than exactly 1.
  if (has_error)
    printf("error detected");

  return has_error;
}

// Returns `value` if it is already a multiple of `multiple`; otherwise adds
// the shortfall (multiple - value % multiple).
int fill_up_to_nearest_multiple(int value, int multiple)
{
  const int remainder = value % multiple;
  if (remainder == 0)
    return value;
  return value + (multiple - remainder);
}

// Launches kdequant_mm_int32_fp16 on `stream` over numRows * numCols elements.
// Each 512-thread CUDA block covers 512 * 4 elements (num_per_thread = 4).
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, cudaStream_t stream)
{
  constexpr int kThreads = 512;
  constexpr int kPerThread = 4;
  constexpr int kPerBlock = kThreads * kPerThread;

  const int total = numRows * numCols;
  const int grid = (total + kPerBlock - 1) / kPerBlock;

  kdequant_mm_int32_fp16<kPerThread, kThreads><<<grid, kThreads, 0, stream>>>(A, rowStats, colStats, out, bias, numRows, numCols, total);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kInt8VectorQuant on `stream` with one 1024-thread CUDA block per row.
// The kernel's third template argument is 1 when threshold is non-zero and 0
// when it is exactly zero.
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
  if (threshold != 0.0)
    kInt8VectorQuant<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
  else
    kInt8VectorQuant<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);

  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kgetRowStats on `stream` with one 1024-thread CUDA block per row.
// The kernel's third template argument is 1 when threshold is non-zero and 0
// when it is exactly zero.
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
  if (threshold != 0.0)
    kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
  else
    kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);

  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Sparse x dense product with cuSPARSE SpMM: C = A * op(B), with alpha = 1,
// beta = 0 and CUDA_R_32F compute.
//   A: COO matrix, A_rows x A_cols, half values, 32-bit zero-based indices.
//   B: dense row-major half matrix; op(B) is B^T when transposed_B is set,
//      in which case B's descriptor is created with swapped dimensions.
//   C: dense row-major half matrix, A_rows x B_cols.
// A workspace buffer is cudaMalloc'd and cudaFree'd on every call.
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
  cusparseSpMatDescr_t matA;
  cusparseDnMatDescr_t matB, matC;

  const float one = 1.0f;
  const float zero = 0.0f;

  const int B_rows_stored = transposed_B ? B_cols : A_cols;
  const int B_cols_stored = transposed_B ? A_cols : B_cols;
  const cusparseOperation_t opB = transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE;

  CHECK_CUSPARSE( cusparseCreateCoo(&matA, A_rows, A_cols, A_nnz,
                                    A_rowidx, A_colidx, A_vals,
                                    CUSPARSE_INDEX_32I,
                                    CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
  CHECK_CUSPARSE( cusparseCreateDnMat(&matC, A_rows, B_cols, ldc, C,
                                      CUDA_R_16F, CUSPARSE_ORDER_ROW) );
  CHECK_CUSPARSE( cusparseCreateDnMat(&matB, B_rows_stored, B_cols_stored, ldb, B,
                                      CUDA_R_16F, CUSPARSE_ORDER_ROW) );

  size_t workspace_bytes = 0;
  CHECK_CUSPARSE( cusparseSpMM_bufferSize(handle,
                                          CUSPARSE_OPERATION_NON_TRANSPOSE, opB,
                                          &one, matA, matB, &zero, matC, CUDA_R_32F,
                                          CUSPARSE_SPMM_ALG_DEFAULT, &workspace_bytes) );
  void *workspace = NULL;
  CUDA_CHECK_RETURN( cudaMalloc(&workspace, workspace_bytes) );

  CHECK_CUSPARSE( cusparseSpMM(handle,
                               CUSPARSE_OPERATION_NON_TRANSPOSE, opB,
                               &one, matA, matB, &zero, matC, CUDA_R_32F,
                               CUSPARSE_SPMM_ALG_DEFAULT, workspace) );

  CHECK_CUSPARSE( cusparseDestroySpMat(matA) );
  CHECK_CUSPARSE( cusparseDestroyDnMat(matB) );
  CHECK_CUSPARSE( cusparseDestroyDnMat(matC) );
  CUDA_CHECK_RETURN( cudaFree(workspace) );
}

// Launches kspmm_coo_very_sparse_naive on the default stream with one
// 256-thread CUDA block per entry of nnz_rows.
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
  const int threads = 256;
  kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, threads>>>(
      max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out,
      dequant_stats, nnz, rowsA, rowsB, colsB);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches gemm_device on the default stream with (m + 31) / 32 CUDA blocks.
// `bits` selects the variant:
//   32 -> gemm_device<T, 32, 32>  with  32 threads per block
//   16 -> gemm_device<T, 16, 160> with 160 threads per block
// Any other value is reported on stderr and nothing is launched (previously
// it was silently ignored). Launch errors are checked like in the other
// launchers in this file.
template <typename T> void gemm_host(int m, int n, int k, T * A,  T* B,  T * out,  int lda, int ldb, int ldc, int bits)
{
  const int num_blocks = (m+31)/32;

  if(bits == 32)
    gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m,  n,  k, A,  B,  out, lda, ldb, ldc);
  else if(bits == 16)
    gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m,  n,  k, A,  B,  out, lda, ldb, ldc);
  else
  {
    fprintf(stderr, "gemm_host: unsupported bits %d\n", bits);
    return;
  }

  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kgemm_4bit_inference on the default stream with (m + 31) / 32 CUDA
// blocks of 96 threads. The launch is now checked with
// CUDA_CHECK_RETURN(cudaPeekAtLastError()), consistent with the other launchers.
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A,  unsigned char* B,  float *absmax, T * out,  int lda, int ldb, int ldc, int blocksize)
{
  const int num_blocks = (m+31)/32;

  kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m,  n,  k, A,  B, absmax, out, lda, ldb, ldc, blocksize);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kgemm_4bit_inference_naive on `stream` with 128-thread CUDA blocks
// and ceil(m / 4) blocks.
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A,  unsigned char* B,  float *absmax, float *datatype, T * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{
  const int threads = 128;
  const int grid = (m + 3) / 4;

  kgemm_4bit_inference_naive<T, 128, BITS><<<grid, threads, 0, stream>>>(
      m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kfunc<T, FUNC> on the default stream with 512-thread CUDA blocks.
// The grid is ceil(n / 512) blocks, capped at 65535 (the kernel presumably
// loops over any elements beyond grid * 512; confirm in the kernel source).
// Returns immediately for n <= 0, because a zero-block launch is an invalid
// configuration.
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
{
  if (n <= 0)
    return;

  const int threads = 512;
  const long max_blocks = 65535;

  // Keep the block count in `long` until after clamping. The old code narrowed
  // n / threads into an int first, which overflows for very large n.
  long blocks = n / threads + (n % threads == 0 ? 0 : 1);
  if (blocks > max_blocks)
    blocks = max_blocks;

  kfunc<T, FUNC><<<(int)blocks, threads>>>(A, B, value, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

//==============================================================
//                   TEMPLATE DEFINITIONS
//==============================================================

Tim Dettmers's avatar
Tim Dettmers committed
584
585
586
587
588
template void func<float, FILL>(float *A, float *B, float value, long n);
template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);

Tim Dettmers's avatar
Tim Dettmers committed
589
template void gemm_4bit_inference<half>(int m, int n, int k, half * A,  unsigned char* B,  float *absmax, half * out,  int lda, int ldb, int ldc, int blocksize);
590
591
592
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A,  unsigned char* B,  float *absmax, float *datatype, half * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A,  unsigned char* B,  float *absmax, float *datatype, __nv_bfloat16 * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A,  unsigned char* B,  float *absmax, float *datatype, float * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
593

Tim Dettmers's avatar
Tim Dettmers committed
594
//template void gemm_host<float>(int m, int n, int k, float * A,  float* B,  float * out,  int lda, int ldb, int ldc, int bits);
Tim Dettmers's avatar
Tim Dettmers committed
595
template void gemm_host<half>(int m, int n, int k, half * A,  half* B,  half * out,  int lda, int ldb, int ldc, int bits);
596

Tim Dettmers's avatar
Tim Dettmers committed
597
598
599
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);

600
601
602
template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
Tim Dettmers's avatar
Tim Dettmers committed
603

Tim Dettmers's avatar
Tim Dettmers committed
604
605
606
607
// Explicit instantiations of quantizeBlockwise<T, STOCHASTIC, DATA_TYPE> for the
// half, float and bfloat16 input element types.
// STOCHASTIC=1 is instantiated only for General8bit (presumably the variant that
// consumes `rand`/`rand_offset`); FP4 and NF4 are instantiated with STOCHASTIC=0 only.

// half input
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);

// float input
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);

// bfloat16 input
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);

617
618
619
620
621
622
623
624
625
// Explicit instantiations of dequantizeBlockwise<T, DATA_TYPE>: every output
// element type (float, half, bfloat16) for every quantization format,
// grouped here by format.

// General8bit
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);

// FP4
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);

// NF4
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
Tim Dettmers's avatar
Tim Dettmers committed
626
627
628
629

// MAKE_optimizer32bit(name, gtype): explicitly instantiates
// optimizer32bit<gtype, name>, where `gtype` is the gradient/parameter element
// type and `name` is the optimizer identifier (ADAM, MOMENTUM, ...).
// The expansion is a complete declaration ending in ';', so invocations need no
// trailing semicolon. Keep every continuation line contiguous: any line inserted
// between the backslash-terminated lines truncates the macro.
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
                float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
                const float beta1, const float beta2, const float beta3, const float alpha, \
                const float eps, const float weight_decay, \
                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
Tim Dettmers's avatar
Tim Dettmers committed
633
634
635

// optimizer32bit instantiations: every optimizer for half, float and bfloat16.
// (No trailing ';' — the macro expansion already ends with one.)
MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, __nv_bfloat16)

MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)

MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)

MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)

MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)

MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, float)
MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
Tim Dettmers's avatar
Tim Dettmers committed
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668

// MAKE_optimizerStatic8bit(name, gtype): explicitly instantiates
// optimizerStatic8bit<gtype, name> for element type `gtype` and optimizer `name`.
// The expansion ends in ';', so invocations need no trailing semicolon.
// The final line deliberately has no trailing backslash: a dangling '\' would
// splice the next source line into the macro body.
#define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
                float *unorm, float max_unorm, float param_norm, \
                float beta1, float beta2, \
                float eps, int step, float lr, \
                float* quantiles1, float* quantiles2, \
                float* max1, float* max2, float* new_max1, float* new_max2, \
                float weight_decay, \
                const float gnorm_scale, int n);

// optimizerStatic8bit instantiations: half and float only (no bfloat16 or
// ADEMAMIX variants are instantiated for the static 8-bit path).
MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float)

MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)

MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)

MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)

MAKE_optimizerStatic8bit(ADAGRAD, half)
MAKE_optimizerStatic8bit(ADAGRAD, float)

Tim Dettmers's avatar
Tim Dettmers committed
674
675
676

// MAKE_optimizerStatic8bitBlockwise(gtype, optim_name): explicitly instantiates
// optimizerStatic8bitBlockwise<gtype, optim_name>.
// NOTE: the argument order here is (type, optimizer), the reverse of
// MAKE_optimizer32bit / MAKE_optimizerStatic8bit; existing invocations depend on it.
// The expansion ends in ';'. The final line deliberately has no trailing
// backslash so the following source line is not spliced into the macro body.
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n);
Tim Dettmers's avatar
Tim Dettmers committed
679
680
681

// optimizerStatic8bitBlockwise instantiations: every optimizer for half, float
// and bfloat16. No trailing ';' after each invocation, because the macro
// expansion already ends with one (this matches the other MAKE_ lists above).
MAKE_optimizerStatic8bitBlockwise(half, ADAM)
MAKE_optimizerStatic8bitBlockwise(float, ADAM)
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM)

MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM)
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM)
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, MOMENTUM)

MAKE_optimizerStatic8bitBlockwise(half, RMSPROP)
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP)
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, RMSPROP)

MAKE_optimizerStatic8bitBlockwise(half, LION)
MAKE_optimizerStatic8bitBlockwise(float, LION)
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION)

MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD)
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD)
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAGRAD)

MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX)
MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX)
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX)
Tim Dettmers's avatar
Tim Dettmers committed
698

Max Ryabinin's avatar
Max Ryabinin committed
699
700
// Explicit instantiations of percentileClipping for half and float gradients;
// the template argument is deduced from the type of `g`.
template void percentileClipping(half *g, float *gnorm_vec, int step, const int n);
template void percentileClipping(float *g, float *gnorm_vec, int step, const int n);
Tim Dettmers's avatar
Tim Dettmers committed
701

702
703
704
// Explicit instantiations of get_leading_dim for each of the ROW, COL and COL32
// template arguments (presumably matrix layout identifiers — see their declaration).
template int get_leading_dim<COL32>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<ROW>(int dim1, int dim2);