// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#ifndef ops_H
#define ops_H

#include <cstdint>
#include <stdio.h>
#include <iostream>
#include <assert.h>

#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
#include <vector>
#include <functional>

// Abort on any CUDA runtime error, reporting the failing line and file.
#define CUDA_CHECK_RETURN(value)                                                                                  \
    {                                                                                                             \
        cudaError_t _m_cudaStat = value;                                                                          \
        if (_m_cudaStat != cudaSuccess) {                                                                         \
            fprintf(stderr, "Error %s at line %d in file %s\n", cudaGetErrorString(_m_cudaStat), __LINE__,        \
                    __FILE__);                                                                                    \
            exit(1);                                                                                              \
        }                                                                                                         \
    }

// Abort on any cuSPARSE error, reporting the failing line and file.
#define CHECK_CUSPARSE(value)                                                                                     \
    {                                                                                                             \
        cusparseStatus_t _m_cudaStat = value;                                                                     \
        if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) {                                                             \
            fprintf(stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_m_cudaStat), __LINE__,    \
                    __FILE__);                                                                                    \
            exit(1);                                                                                              \
        }                                                                                                         \
    }

inline void checkCudaStatus(cudaError_t status) {
    if (status != cudaSuccess) {
        printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
        throw std::logic_error("cuda API failed");
    }
}

inline int checkCublasStatus(cublasStatus_t status) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        printf("cuBLAS API failed with status %d\n", status);
        // throw std::logic_error("cuBLAS API failed");
        return 1;
    }
    return 0;
}

typedef enum Operations_t {
    ksmul = 0,
} Operations_t;

typedef enum Optimizer_t {
    ADAM = 0,
    MOMENTUM = 1,
    RMSPROP = 2,
    LARS = 3,
    ADAGRAD = 4,
    LION = 5,
    ADEMAMIX = 6
} Optimizer_t;

typedef enum Transform_t {
    ROW = 0,
    COL = 1,
    COL32 = 2,
    COL_TURING = 3,
    COL_AMPERE = 4,
} Transform_t;

typedef enum DataType_t {
    General8bit = 0,
    FP4 = 1,
    NF4 = 2,
} DataType_t;

typedef enum Funcs_t {
    FILL = 0,
    ARANGE = 1,
    _MUL = 2,
} Funcs_t;

// Wrappers that create a cuBLAS / cuBLASLt / cuSPARSE library handle on construction.
class Context {
  public:
    cublasHandle_t m_handle;

    Context() {
        cublasHandle_t handle;
        cublasCreate_v2(&handle);
        m_handle = handle;
    }
};

class ContextLt {
  public:
    cublasLtHandle_t m_handle;

    ContextLt() {
        cublasLtHandle_t handle;
        cublasLtCreate(&handle);
        m_handle = handle;
    }
};

class ContextCusparse {
  public:
    cusparseHandle_t m_handle;

    ContextCusparse() {
        cusparseHandle_t handle;
        cusparseCreate(&handle);
        m_handle = handle;
    }
};

// Elementwise and blockwise (de)quantization.
void quantize(float* code, float* A, unsigned char* out, int n);
void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream);

template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(
    float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
    float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, cudaStream_t stream
);

// 32-bit and 8-bit optimizer updates.
template <typename T, int OPTIMIZER>
void optimizer32bit(
    T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1,
    float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr,
    const float gnorm_scale, bool skip_zeros, int n
);

template <typename T, int OPTIMIZER>
void optimizerStatic8bit(
    T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
    float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
    float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
);

template <typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(
    T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
    float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
    float weight_decay, const float gnorm_scale, bool skip_zeros, int n
);

template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n);
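// Usage sketch for the blockwise quantization entry points above (illustrative
// only, not part of this header: buffer names and sizes are assumptions, the
// `code` table is passed as nullptr on the assumption that NF4 uses its fixed
// datatype, and the matching ops.cu implementation must be linked):
//
//     const int n = 4096, blocksize = 64;
//     half *d_A, *d_out;
//     unsigned char* d_q;
//     float* d_absmax;
//     CUDA_CHECK_RETURN(cudaMalloc(&d_A, n * sizeof(half)));
//     CUDA_CHECK_RETURN(cudaMalloc(&d_out, n * sizeof(half)));
//     CUDA_CHECK_RETURN(cudaMalloc(&d_q, n / 2));  // NF4 packs two 4-bit values per byte
//     CUDA_CHECK_RETURN(cudaMalloc(&d_absmax, (n / blocksize) * sizeof(float)));
//     // ... fill d_A with the values to quantize ...
//     quantizeBlockwise<half, 0, NF4>(nullptr, d_A, d_absmax, d_q, nullptr, 0, blocksize, n);
//     dequantizeBlockwise<half, NF4>(nullptr, d_q, d_absmax, d_out, blocksize, n, 0);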
// Dense GEMM entry points (cuBLAS / cuBLASLt).
void gemmex(
    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
    int ldb, int ldc
);
void strided_gemmex(
    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
    int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
);

template <int DTYPE_OUT, int SCALE_ROWS>
int igemmlt(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
);
void cutlass_igemm(
    bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc
);

// Int8 quantization statistics and dequantization of int32 GEMM output.
void dequant_mm_int32_fp16(
    int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
);
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(
    half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
);

// Sparse (COO) matrix multiplication.
void spmm_coo(
    cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols,
    int B_cols, int ldb, half* B, int ldc, half* C, bool transposed_B
);
template <typename T, int BITS>
void spmm_coo_very_sparse_naive(
    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);

// 4-bit inference paths.
void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);

template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);
template <typename T>
void gemm_4bit_inference(
    int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
);
template <typename T, int BITS>
void gemm_4bit_inference_naive(
    int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
    int blocksize, cudaStream_t stream
);

template <typename T, int FUNC> void func(T* A, T* B, T value, long n);

#endif
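// Usage sketch for the dense GEMM entry point above (illustrative only: buffer
// names are assumptions, it assumes the accompanying ops.cu implements gemmex
// as an int8 x int8 -> int32 cublasGemmEx dispatch, and leading dimensions
// follow cuBLAS column-major conventions):
//
//     Context* ctx = new Context();
//     const int m = 64, n = 64, k = 64;
//     int8_t *d_A, *d_B;
//     int32_t* d_C;
//     CUDA_CHECK_RETURN(cudaMalloc(&d_A, m * k));
//     CUDA_CHECK_RETURN(cudaMalloc(&d_B, k * n));
//     CUDA_CHECK_RETURN(cudaMalloc(&d_C, m * n * sizeof(int32_t)));
//     // ... fill d_A and d_B with quantized int8 values ...
//     gemmex(ctx, false, false, m, n, k, d_A, d_B, d_C, m, k, m);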