"git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "cfa80974ddbf9a88d5bd7b6db322e11c876feef8"
pythonInterface.cpp 26.3 KB
Newer Older
1
2
3
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

Max Ryabinin's avatar
Max Ryabinin committed
6
#if BUILD_CUDA
Tim Dettmers's avatar
Tim Dettmers committed
7
#include <ops.cuh>
Max Ryabinin's avatar
Max Ryabinin committed
8
#endif
9
10
11
#if BUILD_MPS
// #include <mps_ops.h>
#endif
Max Ryabinin's avatar
Max Ryabinin committed
12
#include <cpu_ops.h>
Tim Dettmers's avatar
Tim Dettmers committed
13
14

// We cannot call templated code from C, so we wrap each template instantiation in a C-compatible call here where needed.
// We use macro functions to expand all the different optimizers. It looks ugly, and it is ugly, but it's better than
// maintaining all that boilerplate by hand.
//===================================================================================
//                               UNMANGLED CALLS
//===================================================================================

Max Ryabinin's avatar
Max Ryabinin committed
21
#if BUILD_CUDA
Tim Dettmers's avatar
Tim Dettmers committed
22
23
24
25
void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }


Tim Dettmers's avatar
Tim Dettmers committed
26
27
//void gemm_host_fp32(int M, int N, int K, float * A,  float* B,  float * out,  int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
28
void gemm_host_fp16(int M, int N, int K, half * A,  half* B,  half * out,  int lda, int ldb, int ldc)
Tim Dettmers's avatar
Tim Dettmers committed
29
{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
Tim Dettmers's avatar
Tim Dettmers committed
30

Tim Dettmers's avatar
Tim Dettmers committed
31
32
33
void gemm_4bit_inference(int m, int n, int k, half * A,  unsigned char* B,  float *absmax, half * out,  int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax,  out, lda, ldb, ldc, blocksize); }

34
35
void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A,  unsigned char* B,  float *absmax, float *datatype, half * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax,  datatype, out, lda, ldb, ldc, blocksize, stream); }
36

37
38
void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A,  unsigned char* B,  float *absmax, float *datatype, __nv_bfloat16 * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax,  datatype, out, lda, ldb, ldc, blocksize, stream); }
39

40
41
void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A,  unsigned char* B,  float *absmax, float *datatype, float * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax,  datatype, out, lda, ldb, ldc, blocksize, stream); }
42

// MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) defines
//   void <fname>_<type_name>(ctype *A, ctype *B, ctype value, long n)
// which forwards to the func<ctype, FUNC> template (ops.cuh); FUNC selects the operation
// (FILL, ARANGE, _MUL below).
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); }

MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)

Tim Dettmers's avatar
Tim Dettmers committed
51

Tim Dettmers's avatar
Tim Dettmers committed
52
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
53
void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
Tim Dettmers's avatar
Tim Dettmers committed
54
               float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
55
56
               const float beta1, const float beta2, const float beta3, const float alpha, \
			   const float eps, const float weight_decay, \
57
               const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
58
{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
Tim Dettmers's avatar
Tim Dettmers committed
59
60
61

MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
62
63
64
MAKE_FUNC32(adam, ADAM, float, fp32)
MAKE_FUNC32(adam, ADAM, half, fp16)
MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
65
66
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
Tim Dettmers's avatar
Tim Dettmers committed
67
68
69
MAKE_FUNC32(lion, LION, float, fp32)
MAKE_FUNC32(lion, LION, half, fp16)
MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
70
71
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
72
73
74
75
MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32)
MAKE_FUNC32(ademamix, ADEMAMIX, half, fp16)
MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16)

Tim Dettmers's avatar
Tim Dettmers committed
76
77

#define MAKE_FUNC8(fname, oname, gtype, gbits) \
78
void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
Max Ryabinin's avatar
Max Ryabinin committed
79
								float *unorm, float max_unorm, float param_norm, \
Tim Dettmers's avatar
Tim Dettmers committed
80
81
82
83
84
85
                float beta1, float beta2, \
                float eps, int step, float lr,  \
                float* quantiles1, float* quantiles2, \
                float* max1, float* max2, float* new_max1, float* new_max2, \
                float weight_decay, float gnorm_scale, int n) \
{  \
Max Ryabinin's avatar
Max Ryabinin committed
86
87
	optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
			                                  quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
Tim Dettmers's avatar
Tim Dettmers committed
88
89
90
91
92
93
94
95
} \

MAKE_FUNC8(adam, ADAM, float, 32)
MAKE_FUNC8(adam, ADAM, half, 16)
MAKE_FUNC8(momentum, MOMENTUM, float, 32)
MAKE_FUNC8(momentum, MOMENTUM, half, 16)
MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
96
97
MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16)
Tim Dettmers's avatar
Tim Dettmers committed
98
99

#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
100
void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
101
                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
102
                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
103
{	optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
Tim Dettmers's avatar
Tim Dettmers committed
104

Tim Dettmers's avatar
Tim Dettmers committed
105
MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
106
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
107
108
MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
109
MAKE_BLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
110
111
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
112
MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
113
114
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
115
MAKE_BLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
116
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
117
MAKE_BLOCKWISE8(lion, LION, half, fp16)
Tim Dettmers's avatar
Tim Dettmers committed
118
MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
119
MAKE_BLOCKWISE8(lion, LION, float, fp32)
120
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
121
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
122
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
Tim Dettmers's avatar
Tim Dettmers committed
123
124
125
126
127


// Non-template entry points for percentileClipping<T> (ops.cuh).
// g32 takes float gradients and g16 takes half gradients; everything else passes through untouched.
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n)
{
    percentileClipping<float>(g, gnorm_vec, step, n);
}
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n)
{
    percentileClipping<half>(g, gnorm_vec, step, n);
}

Tim Dettmers's avatar
Tim Dettmers committed
128
129
130
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
131
132
133
134
135
136
137

void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }

void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
Tim Dettmers's avatar
Tim Dettmers committed
138
139
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }

140
141
142
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n, stream); } \
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n, stream); } \
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n, stream); } \
143

144
145
146
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
Tim Dettmers's avatar
Tim Dettmers committed
147

148
149
150
void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
// Non-template wrappers around igemmlt<A, B> (ops.cuh); each returns igemmlt's int result unchanged.
// NOTE(review): presumably the first template argument selects the output width (32 -> int32 C,
// 8 -> int8 C) and the second enables row_scale; confirm against ops.cuh.
int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
    return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
    return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
    return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
Tim Dettmers's avatar
Tim Dettmers committed
161
162
163
164
165
166

void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
Tim Dettmers's avatar
Tim Dettmers committed
167
#endif
Tim Dettmers's avatar
Tim Dettmers committed
168

Tim Dettmers's avatar
Tim Dettmers committed
169
170
extern "C"
{
Tim Dettmers's avatar
Tim Dettmers committed
171
#if BUILD_CUDA
Max Ryabinin's avatar
Max Ryabinin committed
172
173
174
	void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
	void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
	void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
175
	void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream){ dequantize(code, A, out, n, stream); }
Max Ryabinin's avatar
Max Ryabinin committed
176

177
178
179
  // C-linkage entry points for blockwise dequantization into half outputs; each forwards
  // unchanged to the matching dequantizeBlockwise_fp16* wrapper.
  void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); }
  void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); }
  void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); }
  void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
182
  void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
Tim Dettmers's avatar
Tim Dettmers committed
183
  void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
184
185
186

  void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
  void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
Tim Dettmers's avatar
Tim Dettmers committed
187
  void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
188

189
190
191
  // C-linkage entry points for blockwise dequantization into float outputs; each forwards
  // unchanged to the matching dequantizeBlockwise_fp32* wrapper.
  void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); }
  void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); }
  void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); }
  // C-linkage entry points for blockwise quantization of bfloat16 inputs; each forwards unchanged
  // to the matching quantizeBlockwise_bf16* wrapper.
  void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
  void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
  void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
  // C-linkage entry points for blockwise dequantization into bfloat16 outputs; each forwards
  // unchanged to the matching dequantizeBlockwise_bf16* wrapper.
  void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); }
  void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); }
  void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); }
	#define MAKE_CFUNC32(name, gtype, gbits) \
202
	void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
Max Ryabinin's avatar
Max Ryabinin committed
203
								 float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
204
205
								 const float beta1, const float beta2, const float beta3, const float alpha, \
								 const float eps, const float weight_decay, \
Max Ryabinin's avatar
Max Ryabinin committed
206
								 const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
207
	{ name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
Max Ryabinin's avatar
Max Ryabinin committed
208

209
210
211
	MAKE_CFUNC32(adam, float, fp32)
	MAKE_CFUNC32(adam, half, fp16)
	MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
Max Ryabinin's avatar
Max Ryabinin committed
212
213
214
215
	MAKE_CFUNC32(momentum, float, 32)
	MAKE_CFUNC32(momentum, half, 16)
	MAKE_CFUNC32(rmsprop, float, 32)
	MAKE_CFUNC32(rmsprop, half, 16)
Tim Dettmers's avatar
Tim Dettmers committed
216
217
218
	MAKE_CFUNC32(lion, float, fp32)
	MAKE_CFUNC32(lion, half, fp16)
	MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
Max Ryabinin's avatar
Max Ryabinin committed
219
220
	MAKE_CFUNC32(adagrad, float, 32)
	MAKE_CFUNC32(adagrad, half, 16)
221
222
223
	MAKE_CFUNC32(ademamix, float, fp32)
	MAKE_CFUNC32(ademamix, half, fp16)
	MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
Max Ryabinin's avatar
Max Ryabinin committed
224
225

	#define MAKE_CFUNC8(name, gtype, gbits) \
226
	void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
Tim Dettmers's avatar
Tim Dettmers committed
227
228
229
230
231
                float *unorm, float max_unorm, float param_norm, \
                float beta1, float beta2, \
                float eps, int step, float lr,  \
                float* quantiles1, float* quantiles2, \
                float* max1, float* max2, float* new_max1, float* new_max2, \
232
                float weight_decay, float gnorm_scale, int n) \
Tim Dettmers's avatar
Tim Dettmers committed
233
  {  \
234
	    name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
Max Ryabinin's avatar
Max Ryabinin committed
235
			                                 quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
Tim Dettmers's avatar
Tim Dettmers committed
236
237
  } \

Max Ryabinin's avatar
Max Ryabinin committed
238
239
240
241
242
243
	MAKE_CFUNC8(adam, float, 32)
	MAKE_CFUNC8(adam, half, 16)
	MAKE_CFUNC8(momentum, float, 32)
	MAKE_CFUNC8(momentum, half, 16)
	MAKE_CFUNC8(rmsprop, float, 32)
	MAKE_CFUNC8(rmsprop, half, 16)
244
245
	MAKE_CFUNC8(lion, float, 32)
	MAKE_CFUNC8(lion, half, 16)
Tim Dettmers's avatar
Tim Dettmers committed
246

Max Ryabinin's avatar
Max Ryabinin committed
247
  #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
248
  void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
249
                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,  \
250
                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
251
  {	fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
Tim Dettmers's avatar
Tim Dettmers committed
252
253
254

	MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
	MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
255
	MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
256
257
	MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
	MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
258
	MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
259
260
	MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
	MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
261
	MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
262
263
	MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
	MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
264
	MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
265
266
	MAKE_CBLOCKWISE8(lion, LION, half, fp16)
	MAKE_CBLOCKWISE8(lion, LION, float, fp32)
Tim Dettmers's avatar
Tim Dettmers committed
267
	MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
268
269
270
	MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
	MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
	MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
Tim Dettmers's avatar
Tim Dettmers committed
271

Max Ryabinin's avatar
Max Ryabinin committed
272
273
274
	// C-linkage entry points for percentile clipping (float / half gradients) and the 2-D
	// histogram scatter-add. Each forwards its arguments unchanged.
	void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
	void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
	void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
	// C-linkage GEMM entry points; they forward unchanged to gemmex / strided_gemmex (ops.cuh)
	// together with the caller-provided Context.
	void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
	{ gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
	void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
			               long strideA, long strideB, long strideC, int batchCount)
	{ strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }

	// Heap-allocate (with `new`) the context objects used by the entry points above. Nothing in
	// this block releases them; the caller owns the returned pointer.
	Context *get_context(){ return new Context(); }
	ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
	// C-linkage int8 GEMM entry points. context->m_handle is cast to cublasLtHandle_t; presumably it
	// holds a cublasHandle_t, which cuBLAS documents as castable to a cuBLASLt handle (TODO confirm).
	// Each returns the wrapped igemmlt_* result unchanged.
	int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
		return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
	}
	int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
		return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
	}
	int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc,  cudaStream_t stream) {
		return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
	}

	// Int8 matmul support: dequantize an int32 product to half (with optional bias), compute per-row
	// statistics, and row-wise vector quantization. All are forwarded unchanged on `stream`.
	void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream)
	{ dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); }
	void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
		getRowStats(A, rowStats, threshold, rows, cols, stream);
	}
	void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
		int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
	}

	// C-linkage sparse (COO) x dense matmul entry points. cspmm_coo casts the ContextCusparse handle
	// to cusparseHandle_t; the *_very_sparse_naive_* variants forward unchanged to the wrappers
	// of the same name without the leading "c".
	void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
	{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }

	void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
	{ spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

	void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
	{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
	//void cgemm_host_fp32(int M, int N, int K, float * A,  float* B,  float * out,  int lda, int ldb, int ldc)
	//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
Tim Dettmers's avatar
Tim Dettmers committed
314

315
	void cgemm_host_fp16(int M, int N, int K, half * A,  half* B,  half * out,  int lda, int ldb, int ldc)
316
317
	{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }

Tim Dettmers's avatar
Tim Dettmers committed
318
319
320
	// C-linkage entry point for the half-precision 4-bit inference GEMM; arguments forwarded unchanged.
	void cgemm_4bit_inference(int m, int n, int k, half * A,  unsigned char* B,  float *absmax, half * out,  int lda, int ldb, int ldc, int blocksize)
	{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
	// Allocate `bytes` of CUDA managed (unified) memory with the cudaMemAttachHost flag and return it.
	// Errors from the allocation and any pending sticky error are routed through CUDA_CHECK_RETURN.
	// `ptr` starts as NULL so that, if CUDA_CHECK_RETURN does not terminate on failure, callers get
	// NULL rather than an indeterminate pointer.
	void *cget_managed_ptr(size_t bytes)
	{
		void *ptr = NULL;
		CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
		CUDA_CHECK_RETURN(cudaPeekAtLastError());

		return ptr;
	}

	void cprefetch(void *ptr, size_t bytes, int device)
	{
332
333
334
335

		int hasPrefetch = 0;
		CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // 40ns overhead
		if (hasPrefetch == 0) return;
336

Tim Dettmers's avatar
Tim Dettmers committed
337
338
339
340
341
342
343
344
345
346
347
348
		CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
		CUDA_CHECK_RETURN(cudaPeekAtLastError());
	}

  // CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) defines the C-linkage entry point
  //   void c<fname>_<type_name>(ctype *A, ctype *B, ctype value, long n)
  // forwarding to <fname>_<type_name>. FUNC is unused here; it is kept so the invocations
  // mirror MAKE_ELEMENTWISE_FUNC.
  #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
	void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); }

	CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
	CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
	CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
	CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
	void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A,  unsigned char* B,  float *absmax, float *datatype, half * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
	{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax,  datatype, out, lda, ldb, ldc, blocksize, stream); }
351

352
353
	void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A,  unsigned char* B,  float *absmax, float *datatype, __nv_bfloat16 * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
	{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax,  datatype, out, lda, ldb, ldc, blocksize, stream); }
354

355
356
	void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A,  unsigned char* B,  float *absmax, float *datatype, float * out,  int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
	{ gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax,  datatype, out, lda, ldb, ldc, blocksize, stream); }
357

Tim Dettmers's avatar
Tim Dettmers committed
358
#endif
359

360
361
	// CPU blockwise (de)quantization entry points. They sit outside the BUILD_CUDA guard, so they are
	// exported in every build, and forward unchanged to quantize_cpu / dequantize_cpu (cpu_ops.h).
	void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
	void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
}