Unverified Commit 4955d136 authored by Matthew Douglas, committed by GitHub

Apply clang-format rules (#1678)

parent 61db0859
@@ -26,10 +26,12 @@ void quantize_block(const quantize_block_args& args) {
        if (idx < 255) {
            float dist_left = fabs(normed_value - (args.code[idx]));
            float dist_right = fabs(normed_value - (args.code[idx + 1]));
            if (dist_right < dist_left) {
                idx += 1;
            }
        }

        // 5. store index
        args.out[i] = (unsigned char)idx;
    }
}
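The hunk above is the nearest-code rounding step: the binary search yields the index of the largest code value not above normed_value, and the extra comparison snaps to whichever of the two adjacent codebook entries is closer. A minimal standalone sketch of the same step (hypothetical helper name, not part of this commit):

#include <cmath>

// Given a sorted 256-entry codebook and the lower-bound index from a binary
// search, return the index of the codebook entry nearest to normed_value.
inline int round_to_nearest_code(const float* code, int idx, float normed_value) {
    if (idx < 255) {
        float dist_left = std::fabs(normed_value - code[idx]);
        float dist_right = std::fabs(normed_value - code[idx + 1]);
        if (dist_right < dist_left)
            idx += 1;
    }
    return idx;
}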
@@ -2,47 +2,48 @@
// TODO: Let's make some of these constexpr and put in a namespace.
#define BNB_CC_MAXWELL 500
#define BNB_CC_MAXWELL2 520
#define BNB_CC_MAXWELL2_X1 530
#define BNB_CC_PASCAL 600
#define BNB_CC_PASCAL_X2 620
#define BNB_CC_VOLTA 700
#define BNB_CC_VOLTA_XAVIER 720
#define BNB_CC_TURING 750
#define BNB_CC_AMPERE 800
#define BNB_CC_AMPERE2 860
#define BNB_CC_AMPERE2_ORIN 870
#define BNB_CC_ADA 890
#define BNB_CC_HOPPER 900
#define BNB_CC_BLACKWELL 1000

#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
#define BNB_WARP_SIZE 32

// The maximum number of resident threads per SM varies by arch.
// For A100/H100 and all prior to Turing, it is 2048, which allows
// for 2 full blocks of 1024 threads per SM.
// Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
#if __CUDA_ARCH__ == 750
#define BNB_MAX_THREADS_PER_SM 1024
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
#define BNB_MAX_THREADS_PER_SM 1536
#else
#define BNB_MAX_THREADS_PER_SM 2048
#endif

// Maximum resident warps per SM is always directly related to the number of threads.
#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))

// Maximum resident blocks per SM may vary.
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
#define BNB_MAX_BLOCKS_PER_SM 16
#else
#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
#endif
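Arch-dependent limits like these are typically consumed in __launch_bounds__ annotations, which let the compiler budget registers for the intended occupancy. A minimal sketch assuming a 256-thread block (hypothetical kernel, not from this commit):

// Hypothetical kernel: 256 threads per block; ask the compiler to plan for
// as many resident blocks per SM as the architecture allows (8 on A100/H100
// and pre-Turing, 6 on sm_86..sm_89, 4 on Turing, per the macros above).
__global__ void __launch_bounds__(256, BNB_MAX_THREADS_PER_SM / 256)
kScaleExample(float* data, float scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] *= scale;
}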
@@ -5,21 +5,18 @@
using namespace BinSearch;

struct quantize_block_args {
    BinAlgo<Scalar, float, Direct2>* bin_searcher;
    float* code;
    float* A;
    float* absmax;
    unsigned char* out;
    long long block_end;
    long long block_idx;
    long long threadidx;
    long long blocksize;
};

void quantize_block(const quantize_block_args& args);

#endif
@@ -4,7 +4,7 @@
using namespace BinSearch;

void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n) {
    for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
        long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
        long long block_end = block_idx + valid_items;
@@ -13,8 +13,7 @@ void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, lo
    }
}

void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) {
    // the default code has range [-0.993, 1.0], which can cause an error in the binary search algorithm used below
    code[0] = -1.0f;
@@ -28,36 +27,35 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
    int thread_wave_size = 256;
    // we chunk the threads into waves of 256 since the max limit is
    // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
    for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) {
        long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
        std::vector<std::thread> threads(valid_chunks);
        std::vector<quantize_block_args> args(valid_chunks);

        int chunks_processed = 0;
        for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) {
            long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
            long long block_end = block_idx + valid_items;

            struct quantize_block_args& arg = args[chunks_processed];
            arg.bin_searcher = &bin_searcher;
            arg.code = code;
            arg.A = A;
            arg.absmax = absmax;
            arg.out = out;
            arg.block_end = block_end;
            arg.block_idx = block_idx;
            arg.threadidx = block_idx / blocksize;
            arg.blocksize = blocksize;

            threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
            chunks_processed += 1;
            if (chunks_processed == valid_chunks) {
                break;
            }
        }

        for (int i = 0; i < valid_chunks; i++)
            threads[i].join();
    }
}
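The wave loop above bounds how many std::thread objects exist at once, since the OS thread limit (16k to 64k on Linux, per the comment) is far below the block count of a large tensor. A generic sketch of the same pattern (hypothetical helper, not part of this commit):

#include <functional>
#include <thread>
#include <vector>

// Run work(i) for i in [0, total), with at most wave_size threads alive at a time.
void run_in_waves(long long total, int wave_size, const std::function<void(long long)>& work) {
    for (long long offset = 0; offset < total; offset += wave_size) {
        long long count = total - offset >= wave_size ? wave_size : total - offset;
        std::vector<std::thread> threads;
        threads.reserve(count);
        for (long long i = 0; i < count; i++)
            threads.emplace_back(work, offset + i);
        for (auto& t : threads)
            t.join(); // finish this wave before launching the next
    }
}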
@@ -4,7 +4,7 @@
#include <iostream>
#include <stdio.h>

void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n);
void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n);

#endif
This source diff could not be displayed because it is too large. You can view the blob instead.
@@ -9,116 +9,129 @@
#ifndef kernels
#define kernels

__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n);
__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n);

template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
__global__ void kQuantizeBlockwise(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
);
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);

template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(
    T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,
    const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
);

template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(
    T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
    const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
    const int n
);

template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(
    T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps,
    const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
);

template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(
    T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1,
    const float beta2, const float eps, const float weight_decay, const int step, const float lr,
    const float gnorm_scale, const bool skip_zeros, const int n
);

template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit1State(
    T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1,
    const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1,
    float* new_max1, const float weight_decay, const float gnorm_scale, const int n
);

template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit1State(
    T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float eps, const int step, const float lr,
    float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale,
    const int n
);

template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit2State(
    T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2,
    float* unorm, const float beta1, const float beta2, const float eps, const int step,
    float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
    float* new_max1, float* new_max2, const float gnorm_scale, const int n
);

template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit2State(
    T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm,
    const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr,
    float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
    float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n
);

template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit2StateBlockwise(
    T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2,
    const float beta3, const float alpha, const float eps, const int step, const float lr,
    float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2,
    float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n
);

template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit1StateBlockwise(
    T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps,
    const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay,
    const float gnorm_scale, const bool skip_zeros, const int n
);

template <typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n);

template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(
    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
    float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);

template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16(
    int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
    half* __restrict__ const bias, const int numRows, const int numCols, const int n
);

template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
    char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);

template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
    int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
    int blocksize
);
template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
    int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
    int lda, int ldb, int ldc, int blocksize
);

template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n);

#endif
@@ -5,63 +5,58 @@
#define NUM 4
#define NUM_BLOCK 4096

static inline MPSGraph* get_graph() {
    static MPSGraph* cur = nil;
    if (!cur) {
        cur = [[MPSGraph alloc] init];
    }
    return cur;
}

static inline id<MTLDevice> get_device() {
    NSError* error = nil;
    static id<MTLDevice> device = nil;
    if (!device) {
        device = MTLCreateSystemDefaultDevice();
    }
    if (!device) {
        NSLog(@"Failed to get MPS device");
        abort();
    }
    return device;
}

static inline id<MTLLibrary> get_library() {
    NSError* error = nil;
    static id<MTLLibrary> library = nil;
    if (!library) {
        library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error];
    }
    if (!library) {
        NSLog(@"Failed to load bitsandbytes.metallib");
        abort();
    }
    return library;
}

/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n)
{
    id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0
    dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out;
}*/

// MPSGraph function for quantize
extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) {
    id<MTLDevice> device = get_device();
    id<MTLLibrary> library = get_library();
    static id<MTLFunction> kernel = nil;
    if (!kernel) {
        kernel = [library newFunctionWithName:@"quantize"];
        if (!kernel) {
            NSLog(@"Failed to load bitsandbytes.metallib");
            abort();
        }
    }
    NSLog(@"Not implemented");
    return nil;
}
@@ -3,175 +3,195 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include <BinSearch.h>
#include <cassert>
#include <common.h>
#include <cub/device/device_scan.cuh>
#include <kernels.cuh>
#include <limits>
#include <ops.cuh>

#define ERR_NOT_IMPLEMENTED 100

using namespace BinSearch;
using std::cout;
using std::endl;
void quantize(float* code, float* A, unsigned char* out, int n) {
    int num_blocks = n / 1024;
    num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
    kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
    int num_blocks = n / 1024;
    num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
    kDequantize<<<num_blocks, 1024, 0, stream>>>(code, A, out, n);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(
    float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
) {
    int num_blocks = n / blocksize;
    num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

    if (blocksize == 4096)
        kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE>
            <<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 2048)
        kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 1024)
        kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 512)
        kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 256)
        kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 128)
        kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 64)
        kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);

    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
    float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, cudaStream_t stream
) {
    // printf("stream==%d\n",stream);
    int num_blocks = n / blocksize;
    num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
    int tile_size = (DATA_TYPE > 0) ? 1024 : 512;

    if (DATA_TYPE > 0)
        kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
            <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
    else
        kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
            <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);

    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
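Both launch helpers above size the grid by rounding up, either with the explicit n % blocksize test or with (n + tile_size - 1) / tile_size; the two forms are equivalent for positive operands. A tiny sketch (hypothetical helper, not part of this commit):

// Ceiling division for positive integers: how many blocks of size `block` cover `n`.
constexpr int ceil_div(int n, int block) { return (n + block - 1) / block; }

static_assert(ceil_div(4096, 1024) == 4, "exact multiple: no extra block");
static_assert(ceil_div(4097, 1024) == 5, "one leftover element adds a block");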
template <typename T, int OPTIMIZER>
void optimizer32bit(
    T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, const float beta1,
    const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step,
    const float lr, const float gnorm_scale, bool skip_zeros, const int n
) {
    int num_blocks = n / 4096;
    num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
    switch (OPTIMIZER) {
    case ADAM:
    case ADEMAMIX:
        if (max_unorm > 0.0f) {
            CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
            kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(
                g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n
            );
            CUDA_CHECK_RETURN(cudaPeekAtLastError());
        }
        kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(
            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr,
            gnorm_scale, skip_zeros, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        break;
    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
        if (max_unorm > 0.0f) {
            CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
            kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>
                <<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
            CUDA_CHECK_RETURN(cudaPeekAtLastError());
        }
        kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
            g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale,
            skip_zeros, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        break;
    case LION:
        // in lion, the momentum update happens after the parameter update
        kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
            g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale,
            skip_zeros, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());

        if (max_unorm > 0.0f) {
            CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
            kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>
                <<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
            CUDA_CHECK_RETURN(cudaPeekAtLastError());
        }
        break;
    }
}
template <typename T, int OPTIMIZER>
void optimizerStatic8bit(
    T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
    float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
    float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
) {
    int num_blocks = n / 4096;
    num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;

    if (max_unorm > 0.0f) {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
    }

    switch (OPTIMIZER) {
    case ADAM:
        CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
        CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float)));
        kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(
            p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1,
            new_max2, gnorm_scale, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(
            p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2,
            max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        break;
    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
        CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
        kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(
            p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
            p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1,
            weight_decay, gnorm_scale, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        break;
    case LION:
        // in lion, the momentum update happens after the parameter update
        kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
            p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1,
            weight_decay, gnorm_scale, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());

        CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
        kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(
            p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n
        );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        break;
    default:
        break;
    }
}
#define BLOCKSIZE_2STATE 256
@@ -179,148 +199,120 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
#define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1

template <typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(
    T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
    float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
    float weight_decay, const float gnorm_scale, bool skip_zeros, int n
) {
    int num_blocks = 0;
    switch (OPTIMIZER) {
    case ADAM:
    case ADEMAMIX:
        num_blocks = n / BLOCKSIZE_2STATE;
        num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
        kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE>
            <<<num_blocks, BLOCKSIZE_2STATE / NUM_2STATE>>>(
                p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1,
                absmax2, weight_decay, gnorm_scale, skip_zeros, n
            );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        break;
    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
    case LION:
        num_blocks = n / BLOCKSIZE_1STATE;
        num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
        kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE>
            <<<num_blocks, BLOCKSIZE_1STATE / NUM_1STATE>>>(
                p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n
            );
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
        break;
    }
}
template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n) {
    int num_blocks = n / 2048;
    num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
    CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float)));
    kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
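percentileClipping maintains a rolling window of the last 100 gradient-norm values, overwriting slot step % 100 on each call, which is why the memset clears exactly one float per step. A host-side sketch of consuming such a window (hypothetical helper; assumes the 100 values were copied back from the device):

#include <algorithm>
#include <vector>

// Given the last 100 gradient norms, return the value at the given
// percentile (e.g. pct = 5 -> the 5th-smallest of 100 entries).
float percentile_of_window(std::vector<float> window, int pct) {
    size_t k = window.size() * pct / 100;
    std::nth_element(window.begin(), window.begin() + k, window.end());
    return window[k];
}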
void gemmex(
    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
    int ldb, int ldc
) {
    const int falpha = 1;
    const int fbeta = 0;
    const void* alpha = &falpha;
    const void* beta = &fbeta;
    cublasStatus_t status;

    status = cublasGemmEx(
        context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k,
        alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, C, CUDA_R_32I, ldc, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP
    );

    if (status != CUBLAS_STATUS_SUCCESS) {
        std::cout << "CUBLAS ERROR: Status " << status << std::endl;
    }
}
void strided_gemmex(
    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
    int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
) {
    const int falpha = 1;
    const int fbeta = 0;
    const void* alpha = &falpha;
    const void* beta = &fbeta;
    cublasStatus_t status;

    // cout << transposeA << transposeB << endl;
    // printf("%i %i %i\n", m,n,k);
    // printf("%i %i %i\n", lda,ldb,ldc);
    // printf("%i %i %i\n", strideA, strideB, strideC);
    // printf("%i\n", batchCount);

    status = cublasGemmStridedBatchedEx(
        context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k,
        alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, C,
        CUDA_R_32I, ldc, (long long int)strideC, batchCount, CUDA_R_32I, CUBLAS_GEMM_DEFAULT
    );

    if (status != CUBLAS_STATUS_SUCCESS) {
        std::cout << "CUBLAS ERROR: Status " << status << std::endl;
    }
}

int roundoff(int v, int d) { return (v + d - 1) / d * d; }
template <int ORDER> cublasLtOrder_t get_order() {
    switch (ORDER) {
    case ROW:
        return CUBLASLT_ORDER_ROW;
        break;
    case COL:
        return CUBLASLT_ORDER_COL;
        break;
    case COL32:
        return CUBLASLT_ORDER_COL32;
        break;
    case COL_TURING:
        return CUBLASLT_ORDER_COL4_4R2_8C;
        break;
    case COL_AMPERE:
        return CUBLASLT_ORDER_COL32_2R_4R4;
        break;
    default:
        break;
    }

    return CUBLASLT_ORDER_ROW;
}

template cublasLtOrder_t get_order<ROW>();
@@ -329,355 +321,394 @@ template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
template <int ORDER> int get_leading_dim(int dim1, int dim2) {
    switch (ORDER) {
    case ROW:
        return dim2;
        break;
    case COL:
        return dim1;
        break;
    case COL32:
        // 32*row tiles
        return dim1 * 32;
        break;
    case COL_TURING:
        return 32 * roundoff(dim1, 8);
        break;
    case COL_AMPERE:
        // 32*32 tiles
        return 32 * roundoff(dim1, 32);
        break;
    default:
        return 0;
        break;
    }
}
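As a worked example of the tiled orders, take dim1 = 100: COL_TURING gives 32 * roundoff(100, 8) = 32 * 104 = 3328, and COL_AMPERE gives 32 * roundoff(100, 32) = 32 * 128 = 4096, since roundoff(v, d) rounds v up to the next multiple of d.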
template <int DTYPE_OUT, int SCALE_ROWS>
int igemmlt(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
) {
    // Calculate C = A^T @ B, in col-major layout.
    //
    // Using the IMMA kernels requires:
    //   * A must be transposed and B must be non-transposed.
    //   * Dimensions m and k must be multiples of 4.
    //   * All pointers must be 4-byte aligned; 16-byte alignment preferred.

    int has_error = 0;

    cublasLtMatmulDesc_t matmulDesc;
    cublasLtMatrixLayout_t aDesc, bDesc, cDesc;
    cublasOperation_t opT = CUBLAS_OP_T;

    cudaDataType_t outType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_8I;
    cudaDataType_t scaleType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_32F;

    cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;

    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_8I, m, k, lda));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_8I, m, n, ldb));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc));

    // Default layout order is col major

    has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType));
    has_error |=
        checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));

    if (DTYPE_OUT == 32) {
        int alpha = 1, beta = 0;
        has_error |= checkCublasStatus(cublasLtMatmul(
            ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, NULL, NULL,
            0, stream
        ));
    } else {
        // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows.

        if (!SCALE_ROWS) {
            float alpha = 1.0f, beta = 0.0f;
            has_error |= checkCublasStatus(cublasLtMatmul(
                ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL,
                NULL, 0, stream
            ));
        } else {
            cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
            float beta = 0.0f;
            has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(
                matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, sizeof(alphaVec)
            ));
            has_error |= checkCublasStatus(cublasLtMatmul(
                ltHandle, matmulDesc, row_scale, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL,
                NULL, 0, stream
            ));
        }
    }

    has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(cDesc));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(bDesc));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc));
    has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
    if (has_error == 1)
        printf("error detected");

    return has_error;
}
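Given the IMMA constraints noted at the top of igemmlt, a caller would validate shapes and alignment before launching; a minimal sketch of such a pre-flight guard (hypothetical helper, not from this commit):

#include <cstdint>

// Hypothetical check mirroring the documented IMMA requirements:
// m and k must be multiples of 4, and all pointers at least 4-byte aligned.
inline bool igemmlt_args_ok(int m, int k, const void* A, const void* B, const void* C) {
    bool dims_ok = (m % 4 == 0) && (k % 4 == 0);
    bool align_ok = (reinterpret_cast<std::uintptr_t>(A) % 4 == 0) &&
                    (reinterpret_cast<std::uintptr_t>(B) % 4 == 0) &&
                    (reinterpret_cast<std::uintptr_t>(C) % 4 == 0);
    return dims_ok && align_ok;
}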
int fill_up_to_nearest_multiple(int value, int multiple) {
    return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}
void dequant_mm_int32_fp16(
    int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
) {
    const int threads = 512;
    const int num_per_thread = 4;
    const int num_per_block = threads * num_per_thread;
    const int n = numRows * numCols;
    const int num_blocks = (n + num_per_block - 1) / num_per_block;

    kdequant_mm_int32_fp16<num_per_thread, threads>
        <<<num_blocks, threads, 0, stream>>>(A, rowStats, colStats, out, bias, numRows, numCols, n);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void int8VectorQuant(
    half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
) {
    if (threshold == 0.0) {
        kInt8VectorQuant<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
    } else {
        kInt8VectorQuant<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
    }
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
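// The branch above lowers the runtime `threshold` flag into a compile-time template
// parameter, so the common threshold == 0 path is compiled without the outlier check.
// A minimal sketch of that dispatch pattern (kQuantSketch/quantSketch are hypothetical
// illustrations, not the library's kInt8VectorQuant):
template <int SPARSE_DECOMP>
__global__ void kQuantSketch(const float* x, int8_t* out, float scale, float threshold, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n)
        return;
    float v = x[i];
    if (SPARSE_DECOMP) // constant-folded away in the <0> instantiation
        v = fabsf(v) >= threshold ? 0.0f : v;
    out[i] = (int8_t)lrintf(fmaxf(-127.0f, fminf(127.0f, v * scale)));
}

void quantSketch(const float* x, int8_t* out, float scale, float threshold, int n, cudaStream_t stream) {
    int blocks = (n + 255) / 256;
    if (threshold == 0.0f)
        kQuantSketch<0><<<blocks, 256, 0, stream>>>(x, out, scale, threshold, n);
    else
        kQuantSketch<1><<<blocks, 256, 0, stream>>>(x, out, scale, threshold, n);
}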
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
    if (threshold == 0.0)
        kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
    else
        kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void spmm_coo(
    cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
    int ldb, half* B, int ldc, half* C, bool transposed_B
) {
    cusparseSpMatDescr_t descA;
    cusparseDnMatDescr_t descB, descC;

    float alpha = 1.0f;
    float beta = 0.0f;
    void* dBuffer = NULL;
    size_t bufferSize = 0;

    CHECK_CUSPARSE(cusparseCreateCoo(
        &descA, A_rows, A_cols, A_nnz, A_rowidx, A_colidx, A_vals, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
        CUDA_R_16F
    ));
    // Create dense matrix C
    CHECK_CUSPARSE(cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, CUDA_R_16F, CUSPARSE_ORDER_ROW));
    // Create dense matrix B
    if (transposed_B) {
        int tmp = A_cols;
        A_cols = B_cols;
        B_cols = tmp;
    }

    CHECK_CUSPARSE(cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, CUDA_R_16F, CUSPARSE_ORDER_ROW));
    // allocate an external buffer if needed
    CHECK_CUSPARSE(cusparseSpMM_bufferSize(
        handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
        transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta,
        descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize
    ));
    CUDA_CHECK_RETURN(cudaMalloc(&dBuffer, bufferSize));

    // execute SpMM
    CHECK_CUSPARSE(cusparseSpMM(
        handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
        transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta,
        descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, dBuffer
    ));

    // destroy matrix/vector descriptors
    CHECK_CUSPARSE(cusparseDestroySpMat(descA));
    CHECK_CUSPARSE(cusparseDestroyDnMat(descB));
    CHECK_CUSPARSE(cusparseDestroyDnMat(descC));
    CUDA_CHECK_RETURN(cudaFree(dBuffer));
}
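// The function above follows the standard cuSPARSE generic-API sequence: create the
// COO/dense descriptors, query the workspace size with cusparseSpMM_bufferSize, run
// cusparseSpMM with that workspace, then destroy the descriptors and free the buffer.
// Note the A_cols/B_cols swap only adjusts descB's dimensions when B arrives transposed;
// the transpose itself is requested via CUSPARSE_OPERATION_TRANSPOSE.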
template <typename T, int BITS>
void spmm_coo_very_sparse_naive(
    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
    kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(
        max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB
    );
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits) {
    int num_blocks = (m + 31) / 32;

    if (bits == 32)
        gemm_device<T, 32, 32><<<num_blocks, 32, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
    if (bits == 16)
        gemm_device<T, 16, 160><<<num_blocks, 160, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
}
template <typename T>
void gemm_4bit_inference(
    int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
) {
    int num_blocks = (m + 31) / 32;
    kgemm_4bit_inference<T, 96><<<num_blocks, 96, 0, 0>>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
template <typename T, int BITS>
void gemm_4bit_inference_naive(
    int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
    int blocksize, cudaStream_t stream
) {
    int num_blocks = (m + 3) / 4;
    kgemm_4bit_inference_naive<T, 128, BITS>
        <<<num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int FUNC> void func(T* A, T* B, T value, long n) {
    int threads = 512;
    int blocks = n / threads;
    blocks = n % threads == 0 ? blocks : blocks + 1;
    blocks = blocks > 65535 ? 65535 : blocks;
    kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
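// The clamp to 65535 blocks above relies on the kernel covering the tail itself;
// the usual companion is a grid-stride loop, sketched here with a hypothetical
// kernel (kGridStrideSketch is illustrative, not the library's kfunc):
template <typename T> __global__ void kGridStrideSketch(T* A, T value, long n) {
    for (long i = (long)blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (long)blockDim.x * gridDim.x)
        A[i] = value; // each thread strides across the full range, so a capped grid still covers all n
}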
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template void func<float, FILL>(float* A, float* B, float value, long n);
template void func<unsigned char, FILL>(unsigned char* A, unsigned char* B, unsigned char value, long n);
template void func<float, ARANGE>(float* A, float* B, float value, long n);
template void func<float, _MUL>(float* A, float* B, float value, long n);
template void gemm_4bit_inference<half>(
    int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
);
template void gemm_4bit_inference_naive<half, 16>(
    int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
    int ldc, int blocksize, cudaStream_t stream
);
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(
    int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
    int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
);
template void gemm_4bit_inference_naive<float, 32>(
    int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
    int ldc, int blocksize, cudaStream_t stream
);

// template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc,
// int bits);
template void gemm_host<half>(int m, int n, int k, half* A, half* B, half* out, int lda, int ldb, int ldc, int bits);

template void spmm_coo_very_sparse_naive<half, 16>(
    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);
template void spmm_coo_very_sparse_naive<signed char, 8>(
    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);

template int igemmlt<32, 0>(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
);
template int igemmlt<8, 0>(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
);
template int igemmlt<8, 1>(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
);

template void quantizeBlockwise<half, 1, General8bit>(
    float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, General8bit>(
    float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, FP4>(
    float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, NF4>(
    float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 1, General8bit>(
    float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, General8bit>(
    float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, FP4>(
    float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, NF4>(
    float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
    const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
    const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
    const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
    const int n
);

template void dequantizeBlockwise<float, General8bit>(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<float, FP4>(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<float, NF4>(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, General8bit>(
    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, FP4>(
    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, NF4>(
    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, FP4>(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, NF4>(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);

#define MAKE_optimizer32bit(name, gtype)                                                                              \
    template void optimizer32bit<gtype, name>(                                                                        \
        gtype * g, gtype * p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm,          \
        const float beta1, const float beta2, const float beta3, const float alpha, const float eps,                  \
        const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,     \
        const int n                                                                                                   \
    );

MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, float)
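// For reference, MAKE_optimizer32bit(ADAM, half) expands to the explicit instantiation
//   template void optimizer32bit<half, ADAM>(half * g, half * p, float* state1, ..., const int n);
// so each invocation above emits one optimizer/dtype specialization into this translation unit.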
#define MAKE_optimizerStatic8bit(name, gtype)                                                                         \
    template void optimizerStatic8bit<gtype, name>(                                                                   \
        gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm,            \
        float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1,                 \
        float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay,            \
        const float gnorm_scale, int n                                                                                \
    );

MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float)
MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)
MAKE_optimizerStatic8bit(ADAGRAD, half)
MAKE_optimizerStatic8bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name)                                                          \
    template void optimizerStatic8bitBlockwise<gtype, optim_name>(                                                    \
        gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3,    \
        float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1,             \
        float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n                           \
    );

MAKE_optimizerStatic8bitBlockwise(half, ADAM);
MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
@@ -696,8 +727,8 @@ MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);

template void percentileClipping(float* g, float* gnorm_vec, int step, const int n);
template void percentileClipping(half* g, float* gnorm_vec, int step, const int n);

template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
...
@@ -3,41 +3,41 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#ifndef ops_H
#define ops_H
#include <assert.h>
#include <cstdint>
#include <iostream>
#include <stdio.h>

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <cusparse.h>

#include <functional>
#include <vector>
#define CUDA_CHECK_RETURN(value)                                                                                      \
    {                                                                                                                 \
        cudaError_t _m_cudaStat = value;                                                                              \
        if (_m_cudaStat != cudaSuccess) {                                                                             \
            fprintf(stderr, "Error %s at line %d in file %s\n", cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
            exit(1);                                                                                                  \
        }                                                                                                             \
    }

#define CHECK_CUSPARSE(value)                                                                                         \
    {                                                                                                                 \
        cusparseStatus_t _m_cudaStat = value;                                                                         \
        if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) {                                                                 \
            fprintf(                                                                                                  \
                stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__   \
            );                                                                                                        \
            exit(1);                                                                                                  \
        }                                                                                                             \
    }
inline void checkCudaStatus(cudaError_t status) {
    if (status != cudaSuccess) {
@@ -49,140 +49,163 @@ inline void checkCudaStatus(cudaError_t status) {
inline int checkCublasStatus(cublasStatus_t status) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        printf("cuBLAS API failed with status %d\n", status);
        // throw std::logic_error("cuBLAS API failed");
        return 1;
    }
    return 0;
}
typedef enum Operations_t {
    ksmul = 0,
} Operations_t;

typedef enum Optimizer_t {
    ADAM = 0,
    MOMENTUM = 1,
    RMSPROP = 2,
    LARS = 3,
    ADAGRAD = 4,
    LION = 5,
    ADEMAMIX = 6
} Optimizer_t;

typedef enum Transform_t {
    ROW = 0,
    COL = 1,
    COL32 = 2,
    COL_TURING = 3,
    COL_AMPERE = 4,
} Transform_t;

typedef enum DataType_t {
    General8bit = 0,
    FP4 = 1,
    NF4 = 2,
} DataType_t;

typedef enum Funcs_t {
    FILL = 0,
    ARANGE = 1,
    _MUL = 2,
} Funcs_t;
class Context {
  public:
    cublasHandle_t m_handle;

    Context() {
        cublasHandle_t handle;
        cublasCreate_v2(&handle);
        m_handle = handle;
    }
};

class ContextLt {
  public:
    cublasLtHandle_t m_handle;

    ContextLt() {
        cublasLtHandle_t handle;
        cublasLtCreate(&handle);
        m_handle = handle;
    }
};

class ContextCusparse {
  public:
    cusparseHandle_t m_handle;

    ContextCusparse() {
        cusparseHandle_t handle;
        cusparseCreate(&handle);
        m_handle = handle;
    }
};
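// Usage sketch: each Context type creates its library handle once in the constructor
// and exposes it as m_handle, e.g.
//   Context ctx; // wraps a cublasHandle_t
//   gemmex(&ctx, false, false, m, n, k, A, B, C, lda, ldb, ldc);
// None of these classes defines a destructor that releases the handle, so callers are
// expected to keep one long-lived instance rather than constructing one per call.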
void quantize(float* code, float* A, unsigned char* out, int n);
void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream);
template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(
    float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
    float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, cudaStream_t stream
);

template <typename T, int OPTIMIZER>
void optimizer32bit(
    T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2,
    float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale,
    bool skip_zeros, int n
);

template <typename T, int OPTIMIZER>
void optimizerStatic8bit(
    T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
    float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
    float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
);

template <typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(
    T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
    float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
    float weight_decay, const float gnorm_scale, bool skip_zeros, int n
);

template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n);

void gemmex(
    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
    int ldb, int ldc
);
void strided_gemmex(
    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
    int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
);

template <int DTYPE_OUT, int SCALE_ROWS>
int igemmlt(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
);
void cutlass_igemm(
    bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc
);
void dequant_mm_int32_fp16(
    int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
);
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(
    half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
);

void spmm_coo(
    cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
    int ldb, half* B, int ldc, half* C, bool transposed_B
);

template <typename T, int BITS>
void spmm_coo_very_sparse_naive(
    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);

void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);
template <typename T>
void gemm_4bit_inference(
    int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
);
template <typename T, int BITS>
void gemm_4bit_inference_naive(
    int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
    int blocksize, cudaStream_t stream
);

template <typename T, int FUNC> void func(T* A, T* B, T value, long n);
#endif
@@ -20,39 +20,60 @@
#if BUILD_CUDA
// void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }

void gemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
    gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16);
}

void gemm_4bit_inference(
    int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
) {
    gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

void gemm_4bit_inference_naive_fp16(
    int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
    int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

void gemm_4bit_inference_naive_bf16(
    int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
    int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive<__nv_bfloat16, 16>(
        m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream
    );
}

void gemm_4bit_inference_naive_fp32(
    int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
    int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC)                                                          \
    void fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { func<ctype, FUNC>(A, B, value, n); }

MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
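// For reference, MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) expands to
//   void fill_fp32(float* A, float* B, float value, long n) { func<float, FILL>(A, B, value, n); }
// i.e. the token-pasted name is the flat C-style symbol the Python bindings load.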
#define MAKE_FUNC32(fname, oname, gtype, gbits)                                                                       \
    void fname##32bit_grad_##gbits(                                                                                   \
        gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm,            \
        const float beta1, const float beta2, const float beta3, const float alpha, const float eps,                  \
        const float weight_decay, const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n     \
    ) {                                                                                                               \
        optimizer32bit<gtype, oname>(                                                                                 \
            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step,  \
            lr, gnorm_scale, skip_zeros, n                                                                            \
        );                                                                                                            \
    }
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -70,19 +91,18 @@ MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32)
MAKE_FUNC32(ademamix, ADEMAMIX, half, fp16)
MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
#define MAKE_FUNC8(fname, oname, gtype, gbits)                                                                        \
    void fname##_static_8bit_grad_##gbits(                                                                            \
        gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm,              \
        float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1,                 \
        float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay,            \
        float gnorm_scale, int n                                                                                      \
    ) {                                                                                                               \
        optimizerStatic8bit<gtype, oname>(                                                                            \
            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2,  \
            max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n                                              \
        );                                                                                                            \
    }
MAKE_FUNC8(adam, ADAM, float, 32)
MAKE_FUNC8(adam, ADAM, half, 16)
@@ -93,11 +113,17 @@ MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits)                                                              \
    void fname##_8bit_blockwise_grad_##gbits(                                                                         \
        gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3,      \
        float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1,             \
        float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n                           \
    ) {                                                                                                               \
        optimizerStatic8bitBlockwise<gtype, optim_name>(                                                              \
            p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1,         \
            absmax2, weight_decay, gnorm_scale, skip_zeros, n                                                         \
        );                                                                                                            \
    }
MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
@@ -118,239 +144,511 @@ MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
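// For reference, MAKE_BLOCKWISE8(adam, ADAM, half, fp16) expands to a wrapper named
// adam_8bit_blockwise_grad_fp16 that forwards its arguments unchanged to
// optimizerStatic8bitBlockwise<half, ADAM>, giving every optimizer/dtype pair its own
// unmangled entry point.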
void percentileClipping_g32(float* g, float* gnorm_vec, int step, const int n) {
    percentileClipping<float>(g, gnorm_vec, step, n);
}

void percentileClipping_g16(half* g, float* gnorm_vec, int step, const int n) {
    percentileClipping<half>(g, gnorm_vec, step, n);
}
void quantizeBlockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}

void quantizeBlockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}

void quantizeBlockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}

void quantizeBlockwise_bf16(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
    quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}

void quantizeBlockwise_bf16_fp4(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
    quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}

void quantizeBlockwise_bf16_nf4(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
    quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}

void quantizeBlockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}

void quantizeBlockwise_fp32_fp4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}

void quantizeBlockwise_fp32_nf4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
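// Note: the *_fp4 and *_nf4 wrappers pass NULL for `code`. The 4-bit data types use
// fixed codebooks built into the kernels, so the dynamic quantization map is only
// consulted on the General8bit path (an observation from the call pattern above).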
void dequantizeBlockwise_fp16(
    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp16_fp4(
    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp16_nf4(
    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp32(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp32_fp4(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp32_nf4(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_bf16(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_bf16_fp4(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_bf16_nf4(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
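// The three igemmlt_* wrappers below pin the <DTYPE_OUT, SCALE_ROWS> template
// arguments: igemmlt_32 produces int32 output, igemmlt_8 produces int8 output with a
// host-side scalar alpha, and igemmlt_8_rowscale appears to apply the device-side
// per-row scale vector (row_scale) via the cuBLASLt pointer-mode path seen in igemmlt.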
int igemmlt_32(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
) {
    return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}

int igemmlt_8(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
) {
    return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}

int igemmlt_8_rowscale(
    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
    int lda, int ldb, int ldc, cudaStream_t stream
) {
    return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) void spmm_coo_very_sparse_naive_fp16(
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive<half, 16>(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) void spmm_coo_very_sparse_naive_int8(
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive<signed char, 8>(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
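// Illustrative sketch only (not part of this diff): a minimal host-side caller for the typed
// wrappers above. The buffer names and the blocksize of 64 are assumptions for the example;
// `n` counts output elements, so an NF4 input packs two elements per byte. The FP4/NF4 wrappers
// ignore their `code` argument and forward NULL, since the 4-bit codebooks are fixed in the kernels.
static void example_dequantize_nf4_fp16(
    unsigned char* d_packed, float* d_absmax, half* d_out, int n, cudaStream_t stream
) {
    dequantizeBlockwise_fp16_nf4(NULL, d_packed, d_absmax, d_out, 64, n, stream);
}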
#endif
extern "C" extern "C" {
{
#if BUILD_CUDA #if BUILD_CUDA
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream){ dequantize(code, A, out, n, stream); }
void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); } dequantize(code, A, out, n, stream);
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); } }
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_fp16_fp4(
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } ) {
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp16(
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); } dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream);
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); } }
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_fp16_nf4(
void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } ) {
void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); } void cquantize_blockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); } quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
    quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32_fp4(
    float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
    quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32_nf4(
    float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
    quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_fp32(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32_fp4(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32_nf4(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cquantize_blockwise_bf16(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
    quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_bf16_fp4(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
    quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_bf16_nf4(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
    quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_bf16(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16_fp4(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16_nf4(
    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
    dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream);
}
#define MAKE_CFUNC32(name, gtype, gbits) \
    void c##name##32bit_grad_##gbits( \
        gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
        const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
        const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, \
        const int n \
    ) { \
        name##32bit_grad_##gbits( \
            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \
            lr, gnorm_scale, skip_zeros, n \
        ); \
    }
MAKE_CFUNC32(adam, float, fp32)
MAKE_CFUNC32(adam, half, fp16)
MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
MAKE_CFUNC32(momentum, float, 32)
MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, fp32)
MAKE_CFUNC32(lion, half, fp16)
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16)
MAKE_CFUNC32(ademamix, float, fp32)
MAKE_CFUNC32(ademamix, half, fp16)
MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
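// For illustration (not part of this diff): MAKE_CFUNC32(adam, float, fp32) expands to the
// exported entry point below, a thin C-ABI shim over the C++ 32-bit optimizer implementation.
//
//   void cadam32bit_grad_fp32(
//       float* g, float* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm,
//       const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
//       const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros,
//       const int n
//   ) {
//       adam32bit_grad_fp32(
//           g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay,
//           step, lr, gnorm_scale, skip_zeros, n
//       );
//   }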
#define MAKE_CFUNC8(name, gtype, gbits) \
    void c##name##_static_8bit_grad_##gbits( \
        gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
        float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
        float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
        float gnorm_scale, int n \
    ) { \
        name##_static_8bit_grad_##gbits( \
            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \
            max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \
        ); \
    }
MAKE_CFUNC8(adam, float, 32)
MAKE_CFUNC8(adam, half, 16)
MAKE_CFUNC8(momentum, float, 32)
MAKE_CFUNC8(momentum, half, 16)
MAKE_CFUNC8(rmsprop, float, 32)
MAKE_CFUNC8(rmsprop, half, 16)
MAKE_CFUNC8(lion, float, 32)
MAKE_CFUNC8(lion, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
    void c##fname##_8bit_blockwise_grad_##gbits( \
        gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
        float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
        float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
    ) { \
        fname##_8bit_blockwise_grad_##gbits( \
            p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \
            weight_decay, gnorm_scale, skip_zeros, n \
        ); \
    }
MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
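// For illustration (not part of this diff): MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) emits
//
//   void cadam_8bit_blockwise_grad_fp16(
//       half* p, half* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3,
//       float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1,
//       float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n
//   ) { ... }
//
// forwarding to adam_8bit_blockwise_grad_fp16. Note that the optim_name argument (ADAM) is never
// referenced in the macro body.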
void cpercentile_clipping_g32(float* g, float* gnorm_vec, int step, const int n) {
percentileClipping_g32(g, gnorm_vec, step, n);
}
void cpercentile_clipping_g16(half* g, float* gnorm_vec, int step, const int n) {
percentileClipping_g16(g, gnorm_vec, step, n);
}
void cigemm(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc
) {
gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc);
}
void cbatched_igemm(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc, long strideA, long strideB, long strideC, int batchCount
) {
strided_gemmex(
context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount
);
}
Context* get_context() { return new Context(); }
ContextCusparse* get_cusparse() { return new ContextCusparse(); }
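// For illustration (not part of this diff): get_context/get_cusparse hand heap-allocated handle
// wrappers across the C ABI as opaque pointers; the cigemmlt_* and cspmm_* entry points below
// recover the raw cublasLt/cusparse handles from Context::m_handle via the casts shown.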
int cigemmlt_32(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_32((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int cigemmlt_8(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_8((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int cigemmlt_8_rowscale(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_8_rowscale((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
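// For illustration (not part of this diff): the three entry points above differ only in the
// igemmlt<...> instantiation they reach (<32, 0>, <8, 0>, <8, 1>), selecting the output dtype and
// whether row scaling is applied; presumably the _rowscale variant is the only one in which the
// row_scale argument is actually consumed.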
void cdequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
) {
dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream);
}
void cget_row_stats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
getRowStats(A, rowStats, threshold, rows, cols, stream);
}
void cint8_vector_quant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
) {
int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
}
void cspmm_coo(
ContextCusparse* context, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
) {
spmm_coo(
(cusparseHandle_t)context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C,
transposed_B
);
}
void cspmm_coo_very_sparse_naive_fp16(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive_fp16(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
void cspmm_coo_very_sparse_naive_int8(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive_int8(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
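// For illustration (not part of this diff): the only difference between the two wrappers above is
// the <T, BITS> instantiation they pin (<half, 16> vs. <signed char, 8>), i.e. the element type and
// bit width of the dense B operand; all other arguments are forwarded unchanged.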
// void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
void cgemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc);
}
void cgemm_4bit_inference(
int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
) {
gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
void* cget_managed_ptr(size_t bytes) {
void* ptr;
CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
CUDA_CHECK_RETURN(cudaPeekAtLastError());
return ptr;
}
void cprefetch(void* ptr, size_t bytes, int device) {
int hasPrefetch = 0;
CUDA_CHECK_RETURN(
cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)
); // 40ns overhead
if (hasPrefetch == 0)
return;
CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
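// Illustrative sketch only (not part of this diff): how the two helpers above compose. The helper
// name and float element type are assumptions for the example. The allocation is host-attached
// (cudaMemAttachHost), so an explicit prefetch can migrate the pages to `device` ahead of first
// use; on devices without concurrent managed access, cprefetch deliberately degrades to a no-op.
static float* example_managed_alloc(size_t count, int device) {
    float* buf = (float*)cget_managed_ptr(count * sizeof(float));
    cprefetch(buf, count * sizeof(float), device);
    return buf;
}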
#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void c##fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { fname##_##type_name(A, B, value, n); }
CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
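// For illustration (not part of this diff): CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) expands to
//
//   void cfill_fp32(float* A, float* B, float value, long n) { fill_fp32(A, B, value, n); }
//
// i.e. each instantiation exports a C-ABI shim over the matching elementwise launcher; the FUNC
// argument is unused in this macro's body.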
void cgemm_4bit_inference_naive_fp16(
int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void cgemm_4bit_inference_naive_bf16(
int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void cgemm_4bit_inference_naive_fp32(
int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
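// For illustration (not part of this diff): the *_naive 4-bit paths differ from cgemm_4bit_inference
// above in taking an explicit `datatype` pointer, understood to be the 16-entry float codebook
// (e.g. FP4 or NF4 values) used to dequantize each packed 4-bit weight before the multiply, while
// absmax supplies the per-block scale, one entry per `blocksize` weights.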
#endif
void cquantize_blockwise_cpu_fp32(
    float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n
) {
quantize_cpu(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_cpu_fp32(
float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n
) {
dequantize_cpu(code, A, absmax, out, blocksize, n);
}
}