// Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. #if BUILD_CUDA #include #endif #if BUILD_MPS // #include #endif #include // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. // We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to // maintain all that boilerplate //=================================================================================== // UNMANGLED CALLS //=================================================================================== #if BUILD_CUDA // void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } void gemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) { gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } void gemm_4bit_inference( int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize ) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } void gemm_4bit_inference_naive_fp16( int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream ) { gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } void gemm_4bit_inference_naive_bf16( int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream ) { gemm_4bit_inference_naive<__nv_bfloat16, 16>( m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream ); } void gemm_4bit_inference_naive_fp32( int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream ) { gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { func(A, B, value, n); } MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) #define MAKE_FUNC32(fname, oname, gtype, gbits) \ void fname##32bit_grad_##gbits( \ gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \ const float weight_decay, const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n \ ) { \ optimizer32bit( \ g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \ lr, gnorm_scale, skip_zeros, n \ ); \ } MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) MAKE_FUNC32(adam, ADAM, float, fp32) MAKE_FUNC32(adam, ADAM, half, fp16) MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(lion, LION, float, fp32) MAKE_FUNC32(lion, LION, half, fp16) MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16) MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, half, 16) MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32) MAKE_FUNC32(ademamix, ADEMAMIX, half, fp16) MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16) #define MAKE_FUNC8(fname, oname, gtype, gbits) \ void fname##_static_8bit_grad_##gbits( \ gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \ float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \ float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \ float gnorm_scale, int n \ ) { \ optimizerStatic8bit( \ g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \ max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \ ); \ } MAKE_FUNC8(adam, ADAM, float, 32) MAKE_FUNC8(adam, ADAM, half, 16) MAKE_FUNC8(momentum, MOMENTUM, float, 32) MAKE_FUNC8(momentum, MOMENTUM, half, 16) MAKE_FUNC8(rmsprop, RMSPROP, float, 32) MAKE_FUNC8(rmsprop, RMSPROP, half, 16) MAKE_FUNC8(lion, LION, float, 32) MAKE_FUNC8(lion, LION, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ void fname##_8bit_blockwise_grad_##gbits( \ gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \ float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \ float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \ ) { \ optimizerStatic8bitBlockwise( \ p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \ weight_decay, gnorm_scale, skip_zeros, n \ ); \ } MAKE_BLOCKWISE8(adam, ADAM, half, fp16) MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(adam, ADAM, float, fp32) MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16) MAKE_BLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32) MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) MAKE_BLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) MAKE_BLOCKWISE8(lion, LION, half, fp16) MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(lion, LION, float, fp32) MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16) MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32) void percentileClipping_g32(float* g, float* gnorm_vec, int step, const int n) { percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half* g, float* gnorm_vec, int step, const int n) { percentileClipping(g, gnorm_vec, step, n); } void quantizeBlockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_bf16( float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n ) { quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_bf16_fp4( float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n ) { quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_bf16_nf4( float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n ) { quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32_fp4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32_nf4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void dequantizeBlockwise_fp16( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp16_fp4( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp16_nf4( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32_fp4( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32_nf4( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_fp4( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_nf4( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); } int igemmlt_32( cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream ) { return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } int igemmlt_8( cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream ) { return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } int igemmlt_8_rowscale( cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream ) { return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } void spmm_coo_very_sparse_naive_fp16( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB ) { spmm_coo_very_sparse_naive( max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB ); } void spmm_coo_very_sparse_naive_int8( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB ) { spmm_coo_very_sparse_naive( max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB ); } #endif extern "C" { #if BUILD_CUDA void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) { dequantize(code, A, out, n, stream); } void cdequantize_blockwise_fp16_fp4( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp16( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp16_nf4( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); } void cquantize_blockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp32_fp4( float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n ) { quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp32_nf4( float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n ) { quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32_fp4( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32_nf4( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); } void cquantize_blockwise_bf16( float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n ) { quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_fp4( float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n ) { quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_nf4( float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n ) { quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_bf16( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16_fp4( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16_nf4( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream ) { dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); } #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_grad_##gbits( \ gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \ const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, \ const int n \ ) { \ name##32bit_grad_##gbits( \ g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \ lr, gnorm_scale, skip_zeros, n \ ); \ } MAKE_CFUNC32(adam, float, fp32) MAKE_CFUNC32(adam, half, fp16) MAKE_CFUNC32(adam, __nv_bfloat16, bf16) MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) MAKE_CFUNC32(lion, float, fp32) MAKE_CFUNC32(lion, half, fp16) MAKE_CFUNC32(lion, __nv_bfloat16, bf16) MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) MAKE_CFUNC32(ademamix, float, fp32) MAKE_CFUNC32(ademamix, half, fp16) MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16) #define MAKE_CFUNC8(name, gtype, gbits) \ void c##name##_static_8bit_grad_##gbits( \ gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \ float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \ float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \ float gnorm_scale, int n \ ) { \ name##_static_8bit_grad_##gbits( \ g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \ max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \ ); \ } MAKE_CFUNC8(adam, float, 32) MAKE_CFUNC8(adam, half, 16) MAKE_CFUNC8(momentum, float, 32) MAKE_CFUNC8(momentum, half, 16) MAKE_CFUNC8(rmsprop, float, 32) MAKE_CFUNC8(rmsprop, half, 16) MAKE_CFUNC8(lion, float, 32) MAKE_CFUNC8(lion, half, 16) #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ void c##fname##_8bit_blockwise_grad_##gbits( \ gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \ float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \ float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \ ) { \ fname##_8bit_blockwise_grad_##gbits( \ p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \ weight_decay, gnorm_scale, skip_zeros, n \ ); \ } MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16) MAKE_CBLOCKWISE8(lion, LION, half, fp16) MAKE_CBLOCKWISE8(lion, LION, float, fp32) MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16) MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32) MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16) void cpercentile_clipping_g32(float* g, float* gnorm_vec, int step, const int n) { percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half* g, float* gnorm_vec, int step, const int n) { percentileClipping_g16(g, gnorm_vec, step, n); } void cigemm( Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc ) { gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); } void cbatched_igemm( Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc, long strideA, long strideB, long strideC, int batchCount ) { strided_gemmex( context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount ); } Context* get_context() { return new Context(); } ContextCusparse* get_cusparse() { return new ContextCusparse(); } int cigemmlt_32( Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream ) { return igemmlt_32((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } int cigemmlt_8( Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream ) { return igemmlt_8((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } int cigemmlt_8_rowscale( Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream ) { return igemmlt_8_rowscale((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } void cdequant_mm_int32_fp16( int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream ) { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); } void cget_row_stats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) { getRowStats(A, rowStats, threshold, rows, cols, stream); } void cint8_vector_quant( half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream ) { int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream); } void cspmm_coo( ContextCusparse* context, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half* B, int ldc, half* C, bool transposed_B ) { spmm_coo( (cusparseHandle_t)context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B ); } void cspmm_coo_very_sparse_naive_fp16( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB ) { spmm_coo_very_sparse_naive_fp16( max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB ); } void cspmm_coo_very_sparse_naive_int8( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB ) { spmm_coo_very_sparse_naive_int8( max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB ); } // void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } void cgemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } void cgemm_4bit_inference( int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize ) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } void* cget_managed_ptr(size_t bytes) { void* ptr; CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); CUDA_CHECK_RETURN(cudaPeekAtLastError()); return ptr; } void cprefetch(void* ptr, size_t bytes, int device) { int hasPrefetch = 0; CUDA_CHECK_RETURN( cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device) ); // 40ns overhead if (hasPrefetch == 0) return; CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void c##fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { fname##_##type_name(A, B, value, n); } CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) void cgemm_4bit_inference_naive_fp16( int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream ) { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } void cgemm_4bit_inference_naive_bf16( int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream ) { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } void cgemm_4bit_inference_naive_fp32( int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream ) { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #endif void cquantize_blockwise_cpu_fp32( float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n ) { quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n ) { dequantize_cpu(code, A, absmax, out, blocksize, n); } }