Unverified Commit 4955d136 authored by Matthew Douglas, committed by GitHub

Apply clang-format rules (#1678)

parent 61db0859
@@ -26,10 +26,12 @@ void quantize_block(const quantize_block_args& args) {
if (idx < 255) {
float dist_left = fabs(normed_value - (args.code[idx]));
float dist_right = fabs(normed_value - (args.code[idx + 1]));
if (dist_right < dist_left) { idx += 1; }
if (dist_right < dist_left) {
idx += 1;
}
}
// 5. store index
args.out[i] = (unsigned char) idx;
args.out[i] = (unsigned char)idx;
}
}
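The hunk above is the nearest-neighbor refinement after the binary search: `idx` points at the closest codebook entry at or below the normalized value, and the branch bumps it by one when the next entry is closer. A minimal host-side sketch of the same rule, assuming a sorted 256-entry codebook (function and variable names are illustrative):

#include <cmath>

// Sketch only: given idx from a binary search that found the largest
// codebook entry <= normed_value, pick the nearer of code[idx] and
// code[idx + 1]. Mirrors the refinement step in quantize_block.
static unsigned char refine_nearest(const float* code, int idx, float normed_value) {
    if (idx < 255) {
        float dist_left = std::fabs(normed_value - code[idx]);
        float dist_right = std::fabs(normed_value - code[idx + 1]);
        if (dist_right < dist_left)
            idx += 1;
    }
    return (unsigned char)idx;
}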
@@ -28,7 +28,8 @@
// The maximum number of resident threads per SM varies by arch.
// For A100/H100 and all prior to Turing, it is 2048, which allows
// for 2 full blocks of 1024 threads per SM.
// Reference: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
// Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
#if __CUDA_ARCH__ == 750
#define BNB_MAX_THREADS_PER_SM 1024
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
......
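The #if chain above hard-codes the per-architecture resident-thread limit at compile time. As a point of reference, the same limit is exposed at runtime through the CUDA runtime API; a minimal host-side sketch (error handling elided):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int dev = 0, max_threads_per_sm = 0;
    cudaGetDevice(&dev);
    // Reports the resident-thread limit that BNB_MAX_THREADS_PER_SM
    // encodes per compute capability above.
    cudaDeviceGetAttribute(&max_threads_per_sm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
    printf("max resident threads per SM: %d\n", max_threads_per_sm);
    return 0;
}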
@@ -5,21 +5,18 @@
using namespace BinSearch;
#define BLOCK_SIZE 16384
struct quantize_block_args {
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
BinAlgo<Scalar, float, Direct2>* bin_searcher;
float* code;
float* A;
float* absmax;
unsigned char* out;
long long block_end;
long long block_idx;
long long threadidx;
long long blocksize;
};
void quantize_block(const quantize_block_args& args);
#endif
@@ -4,7 +4,7 @@
using namespace BinSearch;
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) {
void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n) {
for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items;
@@ -13,8 +13,7 @@ void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, lo
}
}
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
{
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) {
// the default code has range [-0.993, 1.0], which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
@@ -28,15 +27,13 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
int thread_wave_size = 256;
// we chunk the threads into waves of 256 since the max limit is
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
{
for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) {
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
std::vector<std::thread> threads(valid_chunks);
std::vector<quantize_block_args> args(valid_chunks);
int chunks_processed = 0;
for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
{
for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) {
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items;
@@ -53,11 +50,12 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
chunks_processed += 1;
if(chunks_processed == valid_chunks){ break; }
if (chunks_processed == valid_chunks) {
break;
}
}
for (int i = 0; i < valid_chunks; i++)
threads[i].join();
}
}
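The wave loop above keeps at most 256 std::thread objects alive at once because the per-process thread limit on Linux can be as low as roughly 16k. A self-contained sketch of the same dispatch pattern (function and parameter names are illustrative):

#include <algorithm>
#include <thread>
#include <vector>

// Sketch: process num_blocks work items with at most wave_size threads
// alive at any time, mirroring the wave loop in quantize_cpu.
void run_in_waves(long long num_blocks, int wave_size) {
    for (long long offset = 0; offset < num_blocks; offset += wave_size) {
        long long valid = std::min<long long>(wave_size, num_blocks - offset);
        std::vector<std::thread> threads;
        threads.reserve(valid);
        for (long long b = 0; b < valid; b++)
            threads.emplace_back([id = offset + b] { (void)id; /* process block id */ });
        for (auto& t : threads)
            t.join(); // each wave drains fully before the next starts
    }
}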
@@ -4,7 +4,7 @@
#include <iostream>
#include <stdio.h>
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n);
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n);
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n);
void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n);
#endif
@@ -3,26 +3,42 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "kernels.cuh"
#include "common.cuh"
#include <cuda_fp16.h>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include "kernels.cuh"
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cuda_fp16.h>
#include <math_constants.h>
#include <mma.h>
#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
__device__ static float nf4_data[16] = {
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0
};
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
__device__ float atomicMax(float* address, float val) {
@@ -30,42 +46,35 @@ __device__ float atomicMax(float* address, float val) {
int old = *address_as_i, assumed;
do {
assumed = old;
old = atomicCAS(
reinterpret_cast<int*>(address), assumed,
__float_as_int(fmaxf(val, __int_as_float(assumed))));
old = atomicCAS(reinterpret_cast<int*>(address), assumed, __float_as_int(fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
}
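The atomicCAS loop above emulates an atomic max for float by comparing on the int bit pattern and recomputing fmaxf until the exchange succeeds. A minimal usage sketch, assuming *out is zero-initialized before launch (kernel name is illustrative):

// Sketch: reduce |x| into a single global maximum with the CAS-based
// float atomicMax defined above.
__global__ void kAbsMax(const float* x, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        atomicMax(out, fabsf(x[i]));
}
// illustrative launch: cudaMemset(out, 0, sizeof(float));
//                      kAbsMax<<<(n + 255) / 256, 256>>>(x, out, n);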
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
{
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
if((val & 0b0100) == 4) // 0
if((val & 0b0010) == 2) //01
if((val & 0b0001) == 1) // 111
return 0.25000000f*absmax*sign; // 1111
else
return 0.16666667f*absmax*sign; // 1110
else
if((val & 0b0001) == 1) // 110
return 0.50000000f*absmax*sign; // 1101
if ((val & 0b0100) == 4) // 0
if ((val & 0b0010) == 2) // 01
if ((val & 0b0001) == 1) // 111
return 0.25000000f * absmax * sign; // 1111
else
return 0.33333333f*absmax*sign; // 1100
return 0.16666667f * absmax * sign; // 1110
else if ((val & 0b0001) == 1) // 110
return 0.50000000f * absmax * sign; // 1101
else
if((val & 0b0010) == 2) //10
if((val & 0b0001) == 1) // 101
return 1.00000000f*absmax*sign; // 1011
return 0.33333333f * absmax * sign; // 1100
else if ((val & 0b0010) == 2) // 10
if ((val & 0b0001) == 1) // 101
return 1.00000000f * absmax * sign; // 1011
else
return 0.66666667f*absmax*sign; // 1010
return 0.66666667f * absmax * sign; // 1010
else if ((val & 0b0001) == 1) // 100
return 5.208333333e-03f * absmax * sign; // 1001
else
if((val & 0b0001) == 1) // 100
return 5.208333333e-03f*absmax*sign; // 1001
else
return 0.00000000f*absmax*sign; // 1000
return 0.00000000f * absmax * sign; // 1000
}
__device__ unsigned char dQuantizeFP4(float x)
{
__device__ unsigned char dQuantizeFP4(float x) {
// FP4 with bias of 3
// first bit is a sign
// subnormals
@@ -78,7 +87,6 @@ __device__ unsigned char dQuantizeFP4(float x)
// 0b010 = 8
// 0b011 = 12
// we do a binary search
// the pivots are divided by 12 (the FP4 absmax)
// since we assume input data is in [-1.0, 1.0]
@@ -89,148 +97,124 @@ __device__ unsigned char dQuantizeFP4(float x)
int sign = x < 0 ? 0b1000 : 0b0000;
x = fabsf(x);
if(x > 0.29166667f)
if( x > 0.583333f)
if( x > 0.8333333f)
return 0b0011+sign;
else
return 0b0010+sign;
if (x > 0.29166667f)
if (x > 0.583333f)
if (x > 0.8333333f)
return 0b0011 + sign;
else
if(x > 0.4166667f)
return 0b101+sign;
return 0b0010 + sign;
else if (x > 0.4166667f)
return 0b101 + sign;
else
return 0b100+sign;
return 0b100 + sign;
else if (x > 0.0859375f)
if (x > 0.20833333f)
return 0b0111 + sign;
else
if(x > 0.0859375f)
if(x > 0.20833333f)
return 0b0111+sign;
return 0b0110 + sign;
else if (x > 0.00260417f)
return 0b0001 + sign;
else
return 0b0110+sign;
else
if(x > 0.00260417f)
return 0b0001+sign;
else
return 0b0000+sign;
return 0b0000 + sign;
}
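A worked round trip through the pair above: FP4 here is a sign bit plus one of eight magnitudes {0, 1/192, 1/6, 1/4, 1/3, 1/2, 2/3, 1}, with the pivots being their midpoints, e.g. 0.4166667 = (1/3 + 1/2) / 2. For x = 0.45: 0.45 > 0.29166667, 0.45 <= 0.583333, 0.45 > 0.4166667, so dQuantizeFP4 returns 0b0101; dDequantizeFP4Tree(0b0101, absmax) then walks the matching branch and yields 0.5f * absmax, the nearest representable value. A host-side reference encoder expressing the same mapping as a nearest-neighbor scan (a sketch, not the library's code; the bit patterns follow the tree above):

#include <cmath>

// FP4 magnitudes in ascending order and the 3-bit payloads the tree
// assigns to them (sign bit 0b1000 is OR'd in separately).
static const float fp4_mag[8] = {0.0f, 5.208333333e-03f, 0.16666667f, 0.25f,
                                 0.33333333f, 0.5f, 0.66666667f, 1.0f};
static const unsigned char fp4_code[8] = {0b000, 0b001, 0b110, 0b111,
                                          0b100, 0b101, 0b010, 0b011};

static unsigned char quantize_fp4_ref(float x) {
    unsigned char sign = x < 0.0f ? 0b1000 : 0b0000;
    x = std::fabs(x);
    int best = 0;
    for (int i = 1; i < 8; i++) // nearest magnitude; ties resolve downward
        if (std::fabs(x - fp4_mag[i]) < std::fabs(x - fp4_mag[best]))
            best = i;
    return (unsigned char)(fp4_code[best] | sign);
}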
__device__ __forceinline__ float dDequantizeNF4(unsigned char val)
{
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {
// the values for this tree were generated by test_normal_map_tree
// in the file tests/test_functional.py
if((val & 0b1000) == 8)
if((val & 0b0100) == 4) // 1
if((val & 0b0010) == 2) // 11
if((val & 0b0001) == 1) // 111
if ((val & 0b1000) == 8)
if ((val & 0b0100) == 4) // 1
if ((val & 0b0010) == 2) // 11
if ((val & 0b0001) == 1) // 111
return 1.0f;
else
return 0.7229568362236023f;
else
if((val & 0b0001) == 1) // 110
else if ((val & 0b0001) == 1) // 110
return 0.5626170039176941f;
else
return 0.44070982933044434f;
else
if((val & 0b0010) == 2) //10
if((val & 0b0001) == 1) // 101
else if ((val & 0b0010) == 2) // 10
if ((val & 0b0001) == 1) // 101
return 0.33791524171829224f;
else
return 0.24611230194568634f;
else
if((val & 0b0001) == 1) // 100
else if ((val & 0b0001) == 1) // 100
return 0.16093020141124725f;
else
return 0.07958029955625534f;
else
if((val & 0b0100) == 4) // 0
if((val & 0b0010) == 2) //01
if((val & 0b0001) == 1) // 011
else if ((val & 0b0100) == 4) // 0
if ((val & 0b0010) == 2) // 01
if ((val & 0b0001) == 1) // 011
return 0.0f;
else
return -0.09105003625154495f;
else
if((val & 0b0001) == 1) // 010
else if ((val & 0b0001) == 1) // 010
return -0.18477343022823334f;
else
return -0.28444138169288635f;
else
if((val & 0b0010) == 2) //00
if((val & 0b0001) == 1) // 001
else if ((val & 0b0010) == 2) // 00
if ((val & 0b0001) == 1) // 001
return -0.39491748809814453f;
else
return -0.5250730514526367f;
else
if((val & 0b0001) == 1) // 000
else if ((val & 0b0001) == 1) // 000
return -0.6961928009986877f;
else
return -1.0f;
}
__device__ unsigned char dQuantizeNF4(float x)
{
__device__ unsigned char dQuantizeNF4(float x) {
// the values for this tree were generated by test_normal_map_tree
// in the file tests/test_functional.py
if(x > 0.03979014977812767f)
if(x > 0.3893125355243683f) // 1
if(x > 0.6427869200706482f) // 11
if(x > 0.8614784181118011f) // 111
if (x > 0.03979014977812767f)
if (x > 0.3893125355243683f) // 1
if (x > 0.6427869200706482f) // 11
if (x > 0.8614784181118011f) // 111
return 0b1111;
else
return 0b1110;
else
if(x > 0.5016634166240692f) // 110
else if (x > 0.5016634166240692f) // 110
return 0b1101;
else
return 0b1100;
else
if(x > 0.2035212516784668f) // 10
if(x > 0.2920137718319893f) // 101
else if (x > 0.2035212516784668f) // 10
if (x > 0.2920137718319893f) // 101
return 0b1011;
else
return 0b1010;
else
if(x > 0.1202552504837513f) // 100
else if (x > 0.1202552504837513f) // 100
return 0b1001;
else
return 0b1000;
else
if(x > -0.33967943489551544f) // 0
if(x > -0.13791173323988914f) // 01
if(x > -0.045525018125772476f) // 011
else if (x > -0.33967943489551544f) // 0
if (x > -0.13791173323988914f) // 01
if (x > -0.045525018125772476f) // 011
return 0b0111;
else
return 0b0110;
else
if(x > -0.23460740596055984f) // 010
else if (x > -0.23460740596055984f) // 010
return 0b0101;
else
return 0b0100;
else
if(x > -0.6106329262256622f) // 00
if(x > -0.4599952697753906f) // 001
else if (x > -0.6106329262256622f) // 00
if (x > -0.4599952697753906f) // 001
return 0b0011;
else
return 0b0010;
else
if(x > -0.8480964004993439f) // 000
else if (x > -0.8480964004993439f) // 000
return 0b0001;
else
return 0b0000;
}
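Every threshold in dQuantizeNF4 above is the midpoint of two adjacent nf4_data entries: 0.03979014977812767 = (0.0 + 0.07958029955625534) / 2 at the root, 0.3893125355243683 = (0.33791524171829224 + 0.44070982933044434) / 2 one level down, and so on. Unlike FP4, NF4 codes are already ordered by value, so the unrolled tree is equivalent to this linear-scan sketch over the table (host code, for illustration only):

// Sketch: encode x as the NF4 code whose value is nearest, using
// midpoints of adjacent table entries as decision thresholds.
// nf4_values mirrors the __device__ nf4_data table above.
static unsigned char quantize_nf4_ref(const float nf4_values[16], float x) {
    unsigned char idx = 0;
    for (int i = 1; i < 16; i++) {
        float midpoint = 0.5f * (nf4_values[i - 1] + nf4_values[i]);
        if (x > midpoint) // strict >, matching the tree's comparisons
            idx = (unsigned char)i;
    }
    return idx;
}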
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template <typename T> __device__ int sgn(T val)
{
return (T(0) < val) - (val < T(0));
}
template <typename T> __device__ int sgn(T val) { return (T(0) < val) - (val < T(0)); }
template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
template <int STOCHASTIC> __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) {
int pivot = 127;
int upper_pivot = 255;
int lower_pivot = 0;
@@ -240,71 +224,60 @@ __device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
float val = smem_code[pivot];
// i>>=1 = {32, 16, 8, 4, 2, 1}
for(int i = 64; i > 0; i>>=1)
{
if(x > val)
{
for (int i = 64; i > 0; i >>= 1) {
if (x > val) {
lower_pivot = pivot;
lower = val;
pivot+=i;
}
else
{
pivot += i;
} else {
upper_pivot = pivot;
upper = val;
pivot-=i;
pivot -= i;
}
val = smem_code[pivot];
}
if(upper_pivot == 255)
if (upper_pivot == 255)
upper = smem_code[upper_pivot];
if(lower_pivot == 0)
if (lower_pivot == 0)
lower = smem_code[lower_pivot];
if(!STOCHASTIC)
{
if(x > val)
{
float midpoint = (upper+val)*0.5f;
if(x > midpoint)
{
if (!STOCHASTIC) {
if (x > val) {
float midpoint = (upper + val) * 0.5f;
if (x > midpoint) {
return upper_pivot;
}
else
} else
return pivot;
}
else
{
float midpoint = (lower+val)*0.5f;
if(x < midpoint)
} else {
float midpoint = (lower + val) * 0.5f;
if (x < midpoint)
return lower_pivot;
else
return pivot;
}
}
} else {
if (x > val) {
float dist_to_upper = fabsf(upper - x);
float dist_full = upper - val;
if (rand >= dist_to_upper / dist_full)
return upper_pivot;
else
{
if(x > val)
{
float dist_to_upper = fabsf(upper-x);
float dist_full = upper-val;
if(rand >= dist_to_upper/dist_full) return upper_pivot;
else return pivot;
}
return pivot;
} else {
float dist_to_lower = fabsf(lower - x);
float dist_full = val - lower;
if (rand >= dist_to_lower / dist_full)
return lower_pivot;
else
{
float dist_to_lower = fabsf(lower-x);
float dist_full = val-lower;
if(rand >= dist_to_lower/dist_full) return lower_pivot;
else return pivot;
return pivot;
}
}
}
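In the STOCHASTIC branch above, x falls between two adjacent code values and the random draw decides which one it snaps to: `rand >= dist_to_upper / dist_full` selects the upper pivot with probability (x - lower) / (upper - lower), so the expected dequantized value is exactly x and the rounding is unbiased. A scalar sketch of the same rule (an equivalent formulation; names are illustrative):

#include <cassert>

// Sketch: stochastic rounding between two adjacent code values.
// For rand uniform in [0, 1), E[result] == x, which is what makes
// dQuantize<1> unbiased. rand < p_up is equivalent to the kernel's
// rand >= dist_to_upper / dist_full test.
static float stochastic_round(float lower, float upper, float x, float rand) {
    assert(lower <= x && x <= upper);
    float p_up = (x - lower) / (upper - lower); // probability of rounding up
    return rand < p_up ? upper : lower;
}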
template <int SIGNED>
__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x)
{
__device__ __forceinline__ unsigned char
quantize_2D(float* __restrict__ quadrants, float* __restrict__ const smem_code, float x) {
int pivot = 127;
int upper_pivot = 255;
int lower_pivot = 0;
@@ -317,56 +290,48 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran
int offset = 1;
// i>>=1 = {32, 16, 8, 4, 2, 1}
for(int i = 64; i > 0; i>>=1)
{
if(x > val)
{
for (int i = 64; i > 0; i >>= 1) {
if (x > val) {
lower_pivot = pivot;
lower = val;
pivot+=i;
//val = i == 64 ? quadrants[2] : smem_code[pivot];
pivot += i;
// val = i == 64 ? quadrants[2] : smem_code[pivot];
local_pivot += offset;
}
else
{
} else {
upper_pivot = pivot;
upper = val;
pivot-=i;
//val = i == 64 ? quadrants[0] : smem_code[pivot];
pivot -= i;
// val = i == 64 ? quadrants[0] : smem_code[pivot];
local_pivot -= offset;
}
val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot];
offset -= 1;
}
if(x > val)
{
midpoint = (upper+val)*0.5f;
if(x > midpoint)
if (x > val) {
midpoint = (upper + val) * 0.5f;
if (x > midpoint)
return upper_pivot;
else
return pivot;
}
else
{
midpoint = (lower+val)*0.5f;
if(x < midpoint)
} else {
midpoint = (lower + val) * 0.5f;
if (x < midpoint)
return lower_pivot;
else
return pivot;
}
}
__launch_bounds__(TH, 4)
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
{
const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK;
__launch_bounds__(TH, 4) __global__
void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n) {
const int n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
int valid_items = (blockIdx.x + 1 == gridDim.x) ? n - (blockIdx.x * NUM_BLOCK) : NUM_BLOCK;
const int base_idx = (blockIdx.x * NUM_BLOCK);
float vals[NUM];
unsigned char qvals[NUM];
//const int lane_id = threadIdx.x % 2;
// const int lane_id = threadIdx.x % 2;
typedef cub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
@@ -376,16 +341,13 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
__shared__ float smem_code[256];
//__shared__ float smem_code[2][257];
if(threadIdx.x < 256)
{
if (threadIdx.x < 256) {
smem_code[threadIdx.x] = code[threadIdx.x];
//smem_code[0][threadIdx.x] = code[threadIdx.x];
//smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x];
// smem_code[0][threadIdx.x] = code[threadIdx.x];
// smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x];
}
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK)
{
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_BLOCK) {
// number of values already processed in blocks +
// number of values already processed in this block +
// rand_offset % mod value
@@ -394,9 +356,8 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
__syncthreads();
LoadFloat(loadf).Load(&(A[i]), vals, valid_items);
#pragma unroll 4
for(int j = 0; j < NUM; j++)
#pragma unroll 4
for (int j = 0; j < NUM; j++)
qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);
__syncthreads();
@@ -404,25 +365,30 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
}
}
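kQuantize processes one NUM_BLOCK = 4096 tile per loop iteration with TH = 1024 threads loading NUM = 4 values each; since valid_items is computed from blockIdx.x alone, the grid must tile n exactly. A hedged launch sketch (the library's actual host wrapper lives elsewhere; this is illustrative only):

// Sketch: one block per 4096-element tile. The kernel's valid_items
// computation assumes gridDim.x == ceil(n / NUM_BLOCK), so only the
// last block sees a partial tile.
void launch_quantize(float* code, float* A, unsigned char* out, int n) {
    int blocks = (n + 4096 - 1) / 4096;          // NUM_BLOCK == 4096
    kQuantize<<<blocks, 1024>>>(code, A, out, n); // TH == 1024
}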
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
__global__ void kQuantizeBlockwise(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
) {
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
T vals[NUM_PER_TH];
float rand_vals[NUM_PER_TH];
unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH];
// float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockStore<
unsigned char, BLOCK_SIZE / NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH,
cub::BLOCK_STORE_WARP_TRANSPOSE>
StoreChar;
typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_PER_TH> BlockReduce;
typedef cub::BlockLoad<float, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
__shared__ typename LoadT::TempStorage loadt;
__shared__ typename LoadFloat::TempStorage loadf;
@@ -431,12 +397,11 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
__shared__ float smem_code[256];
__shared__ float smem_absmax_value[1];
if(DATA_TYPE == General8bit)
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
if (DATA_TYPE == General8bit)
for (int i = threadIdx.x; i < 256; i += blockDim.x)
smem_code[i] = code[i];
for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
for (int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
local_abs_max = -FLT_MAX;
@@ -447,8 +412,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
// 2. broadcast local max
// 3. normalize inputs and quantize
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++)
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
@@ -461,60 +426,57 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
local_abs_max = smem_absmax_value[0];
if(STOCHASTIC)
{
local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4);
if (STOCHASTIC) {
local_rand_idx = ((blockIdx.x * NUM_BLOCK) + (threadIdx.x * NUM) + rand_offset) % (1024 - 4);
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}
unsigned char packed_4bit = 0;
switch(DATA_TYPE)
{
switch (DATA_TYPE) {
case General8bit:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
if(!STOCHASTIC)
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
if (!STOCHASTIC)
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j]) * local_abs_max);
else
qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j]) * local_abs_max);
}
break;
case FP4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) {
packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
qvals[j] = packed_4bit;
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) {
packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
qvals[j] = packed_4bit;
}
break;
}
__syncthreads();
StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
StoreChar(storec).Store(
&(out[(DATA_TYPE > 0) ? i / 2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items + 1) / 2 : valid_items
);
}
}
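For the FP4 and NF4 paths above, each byte of out carries two 4-bit codes, the even-indexed element in the high nibble, which is why the store covers (valid_items + 1) / 2 bytes at offset i / 2. A minimal pack/unpack sketch:

// Sketch: two 4-bit codes per byte, high nibble first, matching the
// packing in kQuantizeBlockwise and the unpacking in kDequantizeBlockwise.
static unsigned char pack_nibbles(unsigned char first, unsigned char second) {
    return (unsigned char)((first << 4) | (second & 0x0F));
}

static void unpack_nibbles(unsigned char packed, unsigned char* first, unsigned char* second) {
    *first = packed >> 4;    // element 2*j
    *second = packed & 0x0F; // element 2*j + 1
}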
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n)
{
template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) {
const int n_load = (gridDim.x * TILE_SIZE);
int valid_items_load = 0;
int valid_items_store = 0;
const int base_idx = (blockIdx.x * TILE_SIZE);
T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)];
T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];
unsigned char qvals[NUM_PER_TH];
float local_abs_max = -FLT_MAX;
@@ -524,15 +486,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;
for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
{
if (DATA_TYPE > 0)
{
for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
if (DATA_TYPE > 0) {
valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
valid_items_store = min(TILE_SIZE * 2, n - i * 2);
}
else
{
} else {
valid_items_load = min(TILE_SIZE, n - i);
valid_items_store = valid_items_load;
}
@@ -540,72 +498,62 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
// Since blocksize will always be a power-of-2, we avoid more expensive
// division by the blocksize and instead use a shift operation.
// This is equivalent to (i+threadIdx.x*NUM_PER_TH)/blocksize.
local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]);
local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) >> (31 - __clz(blocksize))]);
__syncthreads();
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);
switch (DATA_TYPE)
{
switch (DATA_TYPE) {
case General8bit:
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++)
vals[j] = __ldg(&code[qvals[j]]) * local_abs_max;
break;
case FP4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max;
vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max;
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max;
vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max;
}
break;
}
__syncthreads();
StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store);
StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i * 2 : i]), vals, valid_items_store);
}
}
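The absmax indexing above exploits that blocksize is a power of two: 31 - __clz(blocksize) equals log2(blocksize), so the shift is an exact replacement for the division. A host-side check of the identity using the equivalent compiler builtin (illustrative; __builtin_clz is the GCC/Clang counterpart of CUDA's __clz):

#include <cassert>

// For power-of-two b and non-negative x: x >> (31 - clz(b)) == x / b,
// because 31 - clz(b) == log2(b).
static int div_pow2(int x, int b) { return x >> (31 - __builtin_clz(b)); }

int main() {
    for (int b = 1; b <= 4096; b <<= 1)
        for (int x = 0; x < 100000; x += 7)
            assert(div_pow2(x, b) == x / b);
    return 0;
}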
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
{
__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n) {
const unsigned int numThreads = blockDim.x * gridDim.x;
const int idx = (blockIdx.x * blockDim.x) + threadIdx.x;
__shared__ float smem_code[256];
if(threadIdx.x < 256)
{
if (threadIdx.x < 256) {
smem_code[threadIdx.x] = code[threadIdx.x];
}
__syncthreads();
for (int i = idx;i < n; i += numThreads)
{
for (int i = idx; i < n; i += numThreads) {
out[i] = smem_code[A[i]];
}
}
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit2State(
T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
) {
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
int valid_items = 0;
@@ -614,12 +562,12 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
float s1_vals[NUM_VALS];
float s2_vals[NUM_VALS];
const float correction1 = 1.0f/(1.0f - powf(beta1, step));
const float correction2 = 1.0f/(1.0f - powf(beta2, step));
const float correction1 = 1.0f / (1.0f - powf(beta1, step));
const float correction2 = 1.0f / (1.0f - powf(beta2, step));
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockLoad<float, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_VALS> BlockReduce;
__shared__ union {
typename Load::TempStorage load;
@@ -627,8 +575,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
typename BlockReduce::TempStorage reduce;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
__syncthreads();
@@ -638,60 +585,56 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
#pragma unroll NUM_VALS
for (unsigned int j = 0; j < NUM_VALS; j++)
g_vals[j] = gnorm_scale * ((float)g_vals[j]);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
{
switch(OPTIMIZER)
{
#pragma unroll NUM_VALS
for (unsigned int j = 0; j < NUM_VALS; j++) {
switch (OPTIMIZER) {
case ADAM:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]));
s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j])));
s1_vals[j] *= correction1;
s2_vals[j] *= correction2;
s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
s1_vals[j] = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update
s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update)
break;
}
}
# pragma unroll NUM_VALS-1
for(unsigned int j = 1; j < NUM_VALS; j++)
#pragma unroll NUM_VALS - 1
for (unsigned int j = 1; j < NUM_VALS; j++)
s1_vals[0] += s1_vals[j];
__syncthreads();
s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]);
if(threadIdx.x == 0)
if (threadIdx.x == 0)
atomicAdd(&unorm[0], s1_vals[0]);
__syncwarp();
}
}
#define NUM_PER_THREAD 4
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
template <typename T, int OPTIMIZER>
__launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State(
T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
) {
const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) +
(n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD));
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = 0;
float update_scale = 0.0f;
T g_vals[NUM_PER_THREAD];
T p_vals[NUM_PER_THREAD];
float s1_vals[NUM_PER_THREAD];
float s2_vals[NUM_PER_THREAD];
@@ -700,18 +643,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
// TODO: Mark with [[maybe_unused]] after upgrade to min compiler.
float s3_vals[NUM_PER_THREAD];
const float correction1 = 1.0f - powf(beta1, step);
const float correction2 = sqrtf(1.0f - powf(beta2, step));
const float step_size = -lr*correction2/correction1;
const float step_size = -lr * correction2 / correction1;
if(max_unorm > 0.0f)
{
if (max_unorm > 0.0f) {
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
else{ update_scale = 1.0f; }
if (update_scale > max_unorm * param_norm) {
update_scale = (max_unorm * param_norm) / update_scale;
} else {
update_scale = 1.0f;
}
} else {
update_scale = 1.0f;
}
else{ update_scale = 1.0f; }
typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
@@ -726,9 +671,8 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
typename StoreFloat::TempStorage storef;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) {
valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
@@ -746,15 +690,13 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
#pragma unroll 4
for (unsigned int j = 0; j < NUM_PER_THREAD; j++)
g_vals[j] = gnorm_scale * ((float)g_vals[j]);
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
switch(OPTIMIZER)
{
#pragma unroll 4
for (unsigned int j = 0; j < NUM_PER_THREAD; j++) {
switch (OPTIMIZER) {
case ADEMAMIX:
// m1 update: m1 = beta1 * m1 + (1-beta1) * g
s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]);
@@ -765,11 +707,8 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
// nu update: nu = beta2 * nu + (1-beta2) * g^2
s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]);
p_vals[j] = (float)p_vals[j] - lr * (
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
(sqrtf(s2_vals[j]) / correction2) + eps
)
);
p_vals[j] = (float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) /
((sqrtf(s2_vals[j]) / correction2) + eps));
if (weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
@@ -777,14 +716,14 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
break;
case ADAM:
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) {
s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]));
s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j])));
p_vals[j] = ((float)p_vals[j]) +
(update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (eps * correction2))));
if(weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
if (weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
}
break;
}
@@ -804,15 +743,13 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
}
}
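Condensed, the ADAM case above is the bias-corrected update p += update_scale * step_size * m / (sqrt(v) + eps * correction2), where step_size = -lr * sqrt(1 - beta2^step) / (1 - beta1^step) folds in both corrections and update_scale shrinks the step whenever the update norm accumulated into unorm by the preconditioning kernel exceeds max_unorm * param_norm. A scalar sketch of one element's step under those definitions (names are illustrative):

#include <math.h>

// Sketch: one scalar Adam step matching the ADAM case above.
// m and v are the two optimizer states; update_scale implements the
// max_unorm clipping computed from unorm.
static float adam_step(float p, float g, float* m, float* v, int step, float lr,
                       float beta1, float beta2, float eps, float update_scale) {
    float correction1 = 1.0f - powf(beta1, step);
    float correction2 = sqrtf(1.0f - powf(beta2, step));
    float step_size = -lr * correction2 / correction1;
    *m = *m * beta1 + (1.0f - beta1) * g;
    *v = *v * beta2 + (1.0f - beta2) * g * g;
    // eps is scaled by correction2, matching (sqrtf(s2) + eps * correction2)
    // in the kernel, so the corrected denominator stays consistent.
    return p + update_scale * step_size * (*m / (sqrtf(*v) + eps * correction2));
}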
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit1State(
T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
) {
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
int valid_items = 0;
@@ -820,9 +757,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float s1_vals[NUM_VALS];
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockLoad<float, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_VALS> BlockReduce;
__shared__ union {
typename Load::TempStorage load;
@@ -830,8 +767,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
typename BlockReduce::TempStorage reduce;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
__syncthreads();
@@ -839,72 +775,74 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
#pragma unroll NUM_VALS
for (unsigned int j = 0; j < NUM_VALS; j++)
g_vals[j] = gnorm_scale * ((float)g_vals[j]);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
{
switch(OPTIMIZER)
{
#pragma unroll NUM_VALS
for (unsigned int j = 0; j < NUM_VALS; j++) {
switch (OPTIMIZER) {
case MOMENTUM:
if(step == 1)
if (step == 1)
s1_vals[j] = (float)g_vals[j]; // state update
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm
break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * (float)g_vals[j]); // state update
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
s1_vals[j] =
s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value
s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]); // state update
s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value
s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm
break;
}
}
# pragma unroll
for(unsigned int j = 1; j < NUM_VALS; j++)
#pragma unroll
for (unsigned int j = 1; j < NUM_VALS; j++)
s1_vals[0] += s1_vals[j];
__syncthreads();
s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items);
if(threadIdx.x == 0)
if (threadIdx.x == 0)
atomicAdd(&unorm[0], s1_vals[0]);
__syncwarp();
}
}
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
template <typename T, int OPTIMIZER>
__launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State(
T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1,
const float beta2, const float eps, const float weight_decay, const int step, const float lr,
const float gnorm_scale, const bool skip_zeros, const int n
) {
const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) +
(n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD));
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = 0;
float update_scale = 0.0f;
if(max_unorm > 0.0f)
{
if (max_unorm > 0.0f) {
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; }
else{ update_scale = 1.0f; }
if (update_scale > max_unorm * param_norm + eps) {
update_scale = (max_unorm * param_norm + eps) / update_scale;
} else {
update_scale = 1.0f;
}
} else {
update_scale = 1.0f;
}
else{ update_scale = 1.0f; }
T g_vals[NUM_PER_THREAD];
T p_vals[NUM_PER_THREAD];
@@ -924,9 +862,8 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
typename StoreFloat::TempStorage storef;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) {
valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
@@ -935,40 +872,39 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
__syncthreads();
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
g_vals[j] = gnorm_scale*((float)g_vals[j]);
if(weight_decay > 0.0f)
g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
#pragma unroll 4
for (unsigned int j = 0; j < NUM_PER_THREAD; j++) {
g_vals[j] = gnorm_scale * ((float)g_vals[j]);
if (weight_decay > 0.0f)
g_vals[j] = (float)g_vals[j] + (((float)p_vals[j]) * weight_decay);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
switch(OPTIMIZER)
{
#pragma unroll 4
for (unsigned int j = 0; j < NUM_PER_THREAD; j++) {
if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) {
switch (OPTIMIZER) {
case MOMENTUM:
if(step == 1)
if (step == 1)
s1_vals[j] = (float)g_vals[j];
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
p_vals[j] = ((float)p_vals[j]) + update_scale * (-lr * (s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
p_vals[j] =
((float)p_vals[j]) -
update_scale * (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_vals[j]))));
s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * ((float)g_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) -
update_scale * (lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) - lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps);
break;
}
}
@@ -981,25 +917,21 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
}
}
#define NUM8BIT 16
#define NUM_THREADS 256
#define NUM_PER_BLOCK 4096
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
const float gnorm_scale, const int n)
{
template <typename T, int OPTIMIZER>
__global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit2State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2,
float* unorm, const float beta1, const float beta2, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, const float gnorm_scale, const int n
) {
const int n_full = gridDim.x * NUM_PER_BLOCK;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK);
int valid_items =
n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK);
float g_val = 0.0f;
float local_max_s1 = -FLT_MAX;
float local_max_s2 = -FLT_MAX;
@@ -1015,7 +947,6 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadUInt8::TempStorage loadc;
@@ -1025,17 +956,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
__shared__ float smem_quantiles1[256];
__shared__ float smem_quantiles2[256];
if(threadIdx.x < 256)
{
if (threadIdx.x < 256) {
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x];
}
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS * gridDim.x * NUM8BIT) {
valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
......@@ -1044,37 +973,33 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128);
__syncthreads();
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
#pragma unroll 16
for (int j = 0; j < NUM8BIT; j++) {
g_val = g_vals[j];
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1;
s1_vals[j] += (1.0f-beta1)*g_val;
s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0] * beta1;
s1_vals[j] += (1.0f - beta1) * g_val;
local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
}
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
#pragma unroll 16
for (int j = 0; j < NUM8BIT; j++) {
g_val = g_vals[j];
g_val *= gnorm_scale;
s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2;
s2_vals[j] += (1.0f-beta2)*g_val*g_val;
s2_vals[j] = smem_quantiles2[r_c2[j]] * max2[0] * beta2;
s2_vals[j] += (1.0f - beta2) * g_val * g_val;
local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j]));
}
if(unorm != NULL)
{
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
if (unorm != NULL) {
#pragma unroll 16
for (int j = 0; j < NUM8BIT; j++) {
float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step));
float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step));
s1_vals[j] *= correction1;
s2_vals[j] *= correction2;
float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
local_unorm += update_val*update_val;
float update_val = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update
local_unorm += update_val * update_val;
}
}
}
@@ -1083,17 +1008,17 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
__syncthreads();
local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items);
if(unorm != NULL)
{
if (unorm != NULL) {
__syncthreads();
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
}
if(threadIdx.x == 0)
{
if (threadIdx.x == 0) {
atomicMax(&new_max1[0], local_max_s1);
atomicMax(&new_max2[0], local_max_s2);
if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); }
if (unorm != NULL) {
atomicAdd(&unorm[0], local_unorm);
}
}
}
@@ -1101,20 +1026,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
#define NUM_THREADS2 1024
#define NUM_PER_BLOCK2 4096
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS2, 1)
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
template <typename T, int OPTIMIZER>
__global__ void __launch_bounds__(NUM_THREADS2, 1) kOptimizerStatic8bit2State(
T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n
) {
const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
int valid_items = 0;
float g_val = 0.0f;
@@ -1122,19 +1042,22 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
float s2_vals[NUM_PER_THREAD2];
const float correction1 = 1.0f - powf(beta1, step);
const float correction2 = sqrtf(1.0f - powf(beta2, step));
const float step_size = -lr*correction2/correction1;
//const float step_size = -lr*correction2/correction1;
float new_max_val1 = 1.0f/new_max1[0];
float new_max_val2 = 1.0f/new_max2[0];
const float step_size = -lr * correction2 / correction1;
// const float step_size = -lr*correction2/correction1;
float new_max_val1 = 1.0f / new_max1[0];
float new_max_val2 = 1.0f / new_max2[0];
float update_scale = 1.0f;
if(max_unorm > 0.0f)
{
if (max_unorm > 0.0f) {
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
else{ update_scale = 1.0f; }
if (update_scale > max_unorm * param_norm) {
update_scale = (max_unorm * param_norm) / update_scale;
} else {
update_scale = 1.0f;
}
} else {
update_scale = 1.0f;
}
else{ update_scale = 1.0f; }
unsigned char c1s[NUM_PER_THREAD2];
unsigned char c2s[NUM_PER_THREAD2];
@@ -1156,19 +1079,17 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
typename StoreT::TempStorage storeh;
} temp_storage;
if(threadIdx.x < 512)
{
if(threadIdx.x < 256)
if (threadIdx.x < 512) {
if (threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
else
smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256];
smem_quantiles2[threadIdx.x - 256] = quantiles2[threadIdx.x - 256];
}
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) {
valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
@@ -1177,42 +1098,42 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) {
continue;
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
{
#pragma unroll 4
for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) {
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[c1s[j]];
s1_vals[j] = s1_vals[j]*max1[0];
s1_vals[j] = s1_vals[j] * max1[0];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val));
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1);
// make sure the state1 term still has the same sign after quantization
// (not needed for state2 term which has only positive values)
if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) {
if (s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
s2_vals[j] = smem_quantiles2[c2s[j]];
s2_vals[j] = s2_vals[j]*max2[0];
s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2);
s2_vals[j] = s2_vals[j] * max2[0];
s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val));
c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j] * new_max_val2);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
{
p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
if(weight_decay > 0.0f)
p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay));
#pragma unroll 4
for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) {
p_vals[j] = (T)(((float)p_vals[j]) +
((update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (correction2 * eps))))));
if (weight_decay > 0.0f)
p_vals[j] = update_scale * ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
}
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
@@ -1224,22 +1145,16 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
}
}
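The 8-bit path above never stores states in float: each state byte indexes a 256-entry quantile table scaled by a running absolute maximum, so an update is dequantize (quantiles[c] * max1[0]), an EMA step in float, then requantization against the new max via the dQuantize binary search, plus a sign fix-up for state1 whose quantiles span negative values. A scalar sketch of that round trip (quantize_nearest is a hypothetical stand-in for dQuantize<0>):

// Assumed helper: returns the index of the quantile nearest to x,
// standing in for the dQuantize<0> binary search above.
unsigned char quantize_nearest(const float* quantiles, float x);

// Sketch: dequantize -> EMA update -> requantize for one 8-bit state.
static unsigned char requantize_state(const float* quantiles, unsigned char c,
                                      float old_max, float new_max,
                                      float beta, float g) {
    float s = quantiles[c] * old_max; // dequantize
    s = s * beta + (1.0f - beta) * g; // state update in float
    unsigned char c_new = quantize_nearest(quantiles, s / new_max);
    // Sign fix-up as in the kernel: quantization must not flip the sign
    // of a signed state such as the Adam first moment.
    // (Plain < comparison; ignores the +/-0 edge case signbit() handles.)
    if ((quantiles[c_new] < 0.0f) != (s < 0.0f))
        c_new = s > 0.0f ? c_new + 1 : c_new - 1;
    return c_new;
}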
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
const float weight_decay,
const float gnorm_scale, const int n)
{
template <typename T, int OPTIMIZER>
__global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit1State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1,
const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1,
float* new_max1, const float weight_decay, const float gnorm_scale, const int n
) {
const int n_full = gridDim.x * NUM_PER_BLOCK;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items =
n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK);
float g_val = 0.0f;
float local_max_s1 = -FLT_MAX;
float local_unorm = 0.0f;
......@@ -1252,7 +1167,6 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadUInt8::TempStorage loadc;
......@@ -1261,42 +1175,39 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
__shared__ float smem_quantiles1[256];
if (threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
__syncthreads();
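// The 256-entry quantile codebook is staged in shared memory once per block;
// each dequantization below is then a table lookup scaled by the running
// absmax, i.e. state ~= quantiles1[c] * max1[0].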
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS * NUM8BIT) {
valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
__syncthreads();
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
#pragma unroll 16
for (int j = 0; j < NUM8BIT; j++) {
g_val = g_vals[j];
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0];
switch (OPTIMIZER) {
case ADAGRAD:
case MOMENTUM:
if (step == 1)
s1_vals[j] = (float)g_vals[j];
else
s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]);
if (unorm != NULL)
local_unorm += s1_vals[j] * s1_vals[j];
break;
case LION:
s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val));
break;
}
......@@ -1306,44 +1217,44 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
__syncthreads();
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
if (threadIdx.x == 0) {
atomicMax(&new_max1[0], local_max_s1);
}
if (unorm != NULL) {
__syncthreads();
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
if (threadIdx.x == 0) {
atomicAdd(&unorm[0], local_unorm);
}
}
}
template <typename T, int OPTIMIZER>
__global__ void __launch_bounds__(1024, 1) kOptimizerStatic8bit1State(
T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale,
const int n
) {
const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[NUM_PER_THREAD2];
float new_max_val1 = 1.0f / new_max1[0];
float update_scale = 1.0f;
if (max_unorm > 0.0f) {
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if (update_scale > max_unorm * param_norm) {
update_scale = (max_unorm * param_norm) / update_scale;
} else {
update_scale = 1.0f;
}
} else {
update_scale = 1.0f;
}
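// Update clipping (sketch): with u = sqrt(unorm[0]), the accumulated update
// norm from the precondition kernel, the step is scaled by
// min(1, max_unorm * param_norm / u), so the applied update never exceeds
// max_unorm times the parameter norm.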
unsigned char c1s[NUM_PER_THREAD2];
T p_vals[NUM_PER_THREAD2];
......@@ -1363,69 +1274,69 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
typename StoreT::TempStorage storeh;
} temp_storage;
if (threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) {
valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) {
continue;
}
#pragma unroll 4
for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) {
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if (weight_decay > 0.0f) {
switch (OPTIMIZER) {
case ADAGRAD:
case MOMENTUM:
case RMSPROP:
g_val += ((float)p_vals[j]) * weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay);
break;
}
}
s1_vals[j] = smem_quantiles1[c1s[j]] * max1[0];
switch (OPTIMIZER) {
case ADAGRAD:
case MOMENTUM:
if (step == 1)
s1_vals[j] = g_vals[j];
else
s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) + (-lr * update_scale * (s1_vals[j]));
break;
case LION:
p_vals[j] =
((float)p_vals[j]) - (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_val))));
s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val));
p_vals[j] = ((float)p_vals[j]) - (lr * __fdividef(g_val, sqrtf(s1_vals[j]) + eps));
break;
}
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1);
// make sure state1 term has still the same sign after quantization
if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) {
if (s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
......@@ -1439,15 +1350,13 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
}
}
template <typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n) {
const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
int valid_items = 0;
typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_VALS> BlockReduce;
typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename BlockReduce::TempStorage reduce;
......@@ -1455,64 +1364,42 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
T vals[NUM_VALS];
float local_sum = 0.0f;
for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x * BLOCK_SIZE) {
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
local_sum = 0.0f;
__syncthreads();
LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f);
#pragma unroll NUM_VALS
for (int j = 0; j < NUM_VALS; j++)
local_sum += ((float)vals[j]) * ((float)vals[j]);
local_sum = BlockReduce(reduce).Sum(local_sum, valid_items);
if (threadIdx.x == 0) {
if (step == 1) {
// initialize with the same norm for all positions
// #pragma unroll 10
for (int j = 0; j < 100; j++)
atomicAdd(&gnorm_vec[j], local_sum);
}
} else
atomicAdd(&gnorm_vec[step % 100], local_sum);
}
}
}
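// Illustration only (not part of this file): kPercentileClipping keeps a ring
// buffer of the last 100 squared gradient norms in gnorm_vec, indexed by
// step % 100. A host-side consumer could turn it into a clipping scale roughly
// like this hypothetical helper (the percentile choice is an assumption):
//
//   float percentile_clip_scale(std::vector<float> norms2, float current2) {
//       std::sort(norms2.begin(), norms2.end());
//       float cutoff2 = norms2[norms2.size() * 95 / 100]; // ~95th percentile
//       // shrink the gradient only when the current squared norm is an outlier
//       return current2 > cutoff2 ? std::sqrt(cutoff2 / current2) : 1.0f;
//   }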
#define LANES 2
#define QUAD 3
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2,
const float beta3, const float alpha, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n
) {
// const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
int valid_items = 0;
......@@ -1523,8 +1410,8 @@ kOptimizerStatic8bit2StateBlockwise(
// 2-5%
const float correction1 = 1.0f - __powf(beta1, step);
const float correction2 = sqrtf(1.0f - __powf(beta2, step));
const float step_size = __fdividef(-lr * correction2, correction1);
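// __powf and __fdividef are fast-math intrinsics (lower precision than powf
// and plain division); they only shape these scalar per-step constants, not
// the per-element state, so the accuracy loss is negligible.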
const int lane_id = threadIdx.x % LANES;
float new_local_abs_max1 = -FLT_MAX;
float new_local_abs_max2 = -FLT_MAX;
......@@ -1538,17 +1425,17 @@ kOptimizerStatic8bit2StateBlockwise(
T g_vals[N_PER_TH];
T p_vals[N_PER_TH];
typedef cub::BlockLoad<T, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<T, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257];
__shared__ float smem_quantiles2[LANES][257];
typedef cub::BlockReduce<float, BLOCK_SIZE / N_PER_TH> BlockReduce1;
typedef cub::BlockReduce<float, BLOCK_SIZE / N_PER_TH> BlockReduce2;
typedef cub::BlockReduce<float, BLOCK_SIZE / N_PER_TH> BlockReduce3;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ typename BlockReduce2::TempStorage reduce2;
__shared__ typename BlockReduce3::TempStorage reduce3;
......@@ -1562,30 +1449,27 @@ kOptimizerStatic8bit2StateBlockwise(
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
// init: 0.2 -> 0.23
// 0.23 -> 0.23
smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x];
#pragma unroll
for (unsigned int j = 1; j < LANES; j++) {
smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x];
}
__syncthreads();
#pragma unroll
for (int k = 0; k < QUAD; k++) {
quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)];
quadrants2[k] = smem_quantiles2[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)];
}
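// quadrants1/2 cache three interior pivots of the 256-entry codebook
// (indices 63, 127 and 191 for QUAD=3); quantize_2D compares against these
// first to narrow the search to one 64-entry quadrant, trading a few registers
// for fewer shared-memory probes per quantized value.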
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
// loads: 0.23 -> 0.85/1.44
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
__syncthreads();
......@@ -1605,31 +1489,27 @@ kOptimizerStatic8bit2StateBlockwise(
new_local_abs_max2 = -FLT_MAX;
new_local_abs_max3 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60
#pragma unroll N_PER_TH
for (unsigned int j = 0; j < N_PER_TH; j++) {
if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) {
s2_vals[j] = smem_quantiles2[lane_id][c2s[j]] * absmax2[i / BLOCK_SIZE];
g_val = g_vals[j];
// float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps);
// g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val;
g_val *= gnorm_scale;
s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val));
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE];
s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val));
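// AdEMAMix (sketch): a second, slower gradient EMA (beta3) is kept alongside
// the usual Adam pair, and the parameter step later mixes them roughly as
// p -= lr * (m1_hat + alpha*m3) / (sqrt(v_hat) + eps). Its quantized state
// (c3s) and scale are piggybacked on state1/absmax1, hence the
// absmax1[(n + i) / BLOCK_SIZE] indexing below.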
if (OPTIMIZER == ADEMAMIX) {
// The absmax for the third state is appended to absmax1
s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i) / BLOCK_SIZE];
s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val));
}
}
} else {
s1_vals[j] = 0.0f;
s2_vals[j] = 0.0f;
......@@ -1646,7 +1526,6 @@ kOptimizerStatic8bit2StateBlockwise(
}
}
// reduce: 2.51/1.60 -> 2.67/1.69
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());
......@@ -1655,8 +1534,7 @@ kOptimizerStatic8bit2StateBlockwise(
new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max());
}
if (threadIdx.x == 0) {
smem_exchange1[0] = new_local_abs_max1;
smem_exchange2[0] = new_local_abs_max2;
......@@ -1667,17 +1545,14 @@ kOptimizerStatic8bit2StateBlockwise(
__syncthreads();
if (threadIdx.x == 0) {
absmax1[i / BLOCK_SIZE] = new_local_abs_max1;
absmax2[i / BLOCK_SIZE] = new_local_abs_max2;
if (OPTIMIZER == ADEMAMIX) {
absmax1[(n + i) / BLOCK_SIZE] = new_local_abs_max3;
}
} else {
new_local_abs_max1 = smem_exchange1[0];
new_local_abs_max2 = smem_exchange2[0];
......@@ -1688,25 +1563,23 @@ kOptimizerStatic8bit2StateBlockwise(
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
// reduce: 2.67/1.69 -> 2.67/1.70
#pragma unroll N_PER_TH
for (unsigned int j = 0; j < N_PER_TH; j++) {
// if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) {
if (OPTIMIZER == ADEMAMIX) {
p_vals[j] =
T((float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) /
((sqrtf(s2_vals[j]) / correction2) + eps)));
} else {
p_vals[j] =
(T)(((float)p_vals[j]) +
((step_size * (__fdividef(s1_vals[j], (sqrtf(s2_vals[j]) + (correction2 * eps)))))));
}
if (weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
}
}
......@@ -1714,25 +1587,24 @@ kOptimizerStatic8bit2StateBlockwise(
__syncthreads();
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
// quantization: 2.67/1.70 -> 3.4/3.3
#pragma unroll N_PER_TH
for (unsigned int j = 0; j < N_PER_TH; j++) {
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1));
c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j], new_local_abs_max2));
// make sure state1 term has still the same sign after quantization
// (not needed for state2 term which has only positive values)
if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) {
if (s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
if (OPTIMIZER == ADEMAMIX) {
c3s[j] =
quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j], new_local_abs_max3));
if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) {
c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1;
......@@ -1752,22 +1624,17 @@ kOptimizerStatic8bit2StateBlockwise(
}
}
#define LANES 2
#define QUAD 3
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps,
const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n
) {
// const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
int valid_items = 0;
......@@ -1782,14 +1649,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
T g_vals[N_PER_TH];
T p_vals[N_PER_TH];
typedef cub::BlockLoad<T, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<T, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257];
typedef cub::BlockReduce<float, BLOCK_SIZE / N_PER_TH> BlockReduce1;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ float smem_exchange1[1];
......@@ -1799,22 +1666,22 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
// init: 0.2 -> 0.23
// 0.23 -> 0.23
smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
#pragma unroll
for (unsigned int j = 1; j < LANES; j++)
smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
__syncthreads();
#pragma unroll
for (int k = 0; k < QUAD; k++)
quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
// loads: 0.23 -> 0.85/1.44
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
__syncthreads();
......@@ -1826,47 +1693,45 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
new_local_abs_max1 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60
#pragma unroll N_PER_TH
for (unsigned int j = 0; j < N_PER_TH; j++) {
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) {
if (weight_decay > 0.0f) {
switch (OPTIMIZER) {
case MOMENTUM:
case ADAGRAD:
case RMSPROP:
g_val += ((float)p_vals[j]) * weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay);
break;
}
}
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE];
switch (OPTIMIZER) {
case MOMENTUM:
if (step == 1)
s1_vals[j] = g_val;
else
s1_vals[j] = (s1_vals[j] * beta1) + g_val;
break;
case LION:
// here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update,
// before the momentum is updated by beta2
g_vals[j] = lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * g_val));
s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + (g_val * g_val);
break;
}
}
......@@ -1874,41 +1739,37 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
}
// reduce: 2.51/1.60 -> 2.67/1.69
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
if (threadIdx.x == 0)
smem_exchange1[0] = new_local_abs_max1;
__syncthreads();
if (threadIdx.x == 0)
absmax1[i / BLOCK_SIZE] = new_local_abs_max1;
else
new_local_abs_max1 = smem_exchange1[0];
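// Standard block broadcast: thread 0 owns the reduced absmax, publishes it
// through smem_exchange1 and stores it to global memory; after the barrier
// every other thread reads it back, so all lanes quantize against the same
// scale.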
// reduce: 2.67/1.69 -> 2.67/1.70
#pragma unroll N_PER_TH
for (unsigned int j = 0; j < N_PER_TH; j++) {
if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) {
switch (OPTIMIZER) {
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr * (s1_vals[j]);
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps));
break;
case ADAGRAD:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps));
break;
}
}
......@@ -1918,17 +1779,15 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
__syncthreads();
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
// quantization: 2.67/1.70 -> 3.4/3.3
#pragma unroll N_PER_TH
for (unsigned int j = 0; j < N_PER_TH; j++) {
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1));
// make sure state1 term has still the same sign after quantization
// (not needed for state2 term which has only positive values)
if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) {
if (s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
......@@ -1945,9 +1804,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
// Outputs:
// rowStats [rows]
// out [rows, cols]
template <typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
// Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
......@@ -2009,9 +1868,9 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
}
}
template <typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols) {
using BlockReduceT = cub::BlockReduce<float, THREADS>;
// One block per row.
......@@ -2049,25 +1908,24 @@ __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshol
}
}
template __global__ void
kgetRowStats<half, 1024, 0>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template __global__ void
kgetRowStats<half, 1024, 1>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 0>(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols
);
template __global__ void kInt8VectorQuant<half, 1024, 1>(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols
);
#define MM_DEQUANT_CONST 6.200012e-05f // 1.0f/(127.0f*127.0f)
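// Sanity check: 127*127 = 16129 and 1/16129 ~= 6.200012e-05, so one multiply
// undoes both symmetric int8 scalings (row and column). Example: A = 16129
// with rowStats = colStats = 1 dequantizes to ~1.0 (plus bias).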
template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16(
int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
half* __restrict__ const bias, const int numRows, const int numCols, const int n
) {
const int n_out = numRows * numCols;
......@@ -2086,7 +1944,7 @@ __global__ void kdequant_mm_int32_fp16(
int row_idx, col_idx;
#pragma unroll ITEMS_PER_THREAD
for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
row_idx = (block_offset + thread_offset + j) / numCols;
......@@ -2098,19 +1956,18 @@ __global__ void kdequant_mm_int32_fp16(
}
// Each block loads THREADS * ITEMS_PER_THREAD values from A
int valid_items =
block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset;
LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);
#pragma unroll ITEMS_PER_THREAD
for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
local_output[j] = __float2half(
fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j])
);
}
#pragma unroll ITEMS_PER_THREAD
for (int j = 0; j < ITEMS_PER_THREAD; j++) {
int outIdx = block_offset + thread_offset + j;
if (outIdx < n_out) {
......@@ -2119,12 +1976,15 @@ __global__ void kdequant_mm_int32_fp16(
}
}
#define DENORM (1.0f / 127.0f)
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE (8 * 256)
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
) {
// 0. load balancing: We process rows with the most columns first (count_vec) and we process one row per block.
// If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
......@@ -2139,12 +1999,12 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
const int count = max_count[blockIdx.x];
const int local_max_idx = max_idx[blockIdx.x];
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx - 1];
const int local_row_idx = rowidx[offset];
const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32;
const int warp_offset = (warp_id * 32) * SPMM_ITEMS;
const int num_items = BITS == 8 ? 8 : 8;
int idx_col_B = warp_offset;
int local_idx_col_B_offset = 0;
......@@ -2157,10 +2017,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
// 128 byte loads per warp == 4 bytes per thread
// 2. Load A into registers
for (int j = 0; j < MAX_SPARSE_COUNT; j++) {
local_valA[j] = j < count ? values[offset + j] : __float2half(0.0f);
local_colidxA[j] = j < count ? colidx[offset + j] : 0;
}
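// With at most MAX_SPARSE_COUNT = 32 nonzeros per row, the whole sparse row
// (values plus column indices) fits in registers; each pass of the loop below
// then streams a fresh tile of dense B columns against the same cached row.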
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
......@@ -2169,124 +2028,119 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
// added 3 bytes = 6 values between warps should reduce bank conflicts
__shared__ half smem_dequant_stats[SMEM_SIZE];
while (idx_col_B < colsB) {
if (dequant_stats != NULL) {
for (int i = threadIdx.x; i < SMEM_SIZE; i += blockDim.x)
if ((idx_col_B + i - local_idx_col_B_offset) < colsB)
smem_dequant_stats[i] = dequant_stats[idx_col_B + i - local_idx_col_B_offset];
__syncthreads();
}
#pragma unroll SPMM_ITEMS
for (int j = 0; j < SPMM_ITEMS; j++)
local_valC[j] = 0.0f;
#pragma unroll
for (int i = 0; i < count; i++) {
// 3. each warp loads all required rows of B but each warp is offset by k
int row_offset = colsB * local_colidxA[i];
#pragma unroll SPMM_ITEMS
for (int j = 0; j < SPMM_ITEMS; j += num_items) {
// 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached
int idx = idx_col_B + (warp_idx * SPMM_ITEMS) + j;
if (idx >= colsB) {
break;
}
if ((idx + num_items < colsB)) {
if (BITS == 8)
reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] =
reinterpret_cast<float2*>(B)[(row_offset + idx) / num_items];
else
reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] =
reinterpret_cast<float4*>(B)[(row_offset + idx) / num_items];
} else {
#pragma unroll num_items
for (int k = 0; k < num_items; k++)
if (idx + k < colsB)
local_valsB[k] = B[row_offset + idx + k];
else
local_valsB[k] = 0.0f;
}
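// Vectorized fast path (sketch): num_items is 8, so a single float2 load
// (8 bytes) covers 8 int8 values of B and a single float4 load (16 bytes)
// covers 8 halves; the scalar loop above only handles the ragged tail of a
// row.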
#pragma unroll num_items
for (int k = 0; k < num_items; k++) {
if (BITS == 8 && dequant_stats != NULL)
// we do texture cache reads (__ldg) on dequant_stats which should be super fast
{
float valB = local_valsB[k];
float valA = local_valA[i];
if (valB != 0.0 && valA != 0.0)
local_valC[j + k] =
(float)local_valC[j + k] +
((float)smem_dequant_stats[idx + k - local_idx_col_B_offset]) * DENORM * valB * valA;
} else
local_valC[j + k] = (float)local_valC[j + k] + (float)local_valsB[k] * (float)local_valA[i];
}
}
}
int idx_row_C = (colsB * local_row_idx);
#pragma unroll SPMM_ITEMS
for (int j = 0; j < SPMM_ITEMS; j += num_items) {
// int idx_col_C = idx_col_B + (32*j) + warp_idx;
int idx_col_C = idx_col_B + warp_idx * SPMM_ITEMS + j;
int idx_val = idx_col_C + idx_row_C;
if (idx_col_C + num_items < colsB) {
// load outputs to do inplace addition
reinterpret_cast<float4(&)[num_items / 4]>(local_valOut)[0] =
reinterpret_cast<float4*>(out)[idx_val / num_items];
#pragma unroll num_items
for (int k = 0; k < num_items; k++)
local_valC[(j / num_items) + k] = (float)local_valC[(j / num_items) + k] + (float)local_valOut[k];
reinterpret_cast<float4*>(out)[idx_val / num_items] =
reinterpret_cast<float4(&)[num_items]>(local_valC)[j / num_items];
} else {
#pragma unroll num_items
for (int k = 0; k < num_items; k++)
if (idx_col_C + k < colsB)
out[idx_val + k] = (float)out[idx_val + k] + (float)local_valC[j + k];
}
}
idx_col_B += blockDim.x * SPMM_ITEMS;
local_idx_col_B_offset += blockDim.x * SPMM_ITEMS;
}
}
#define WARPS 3
template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc) {
#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x * 32;
const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS - 1) * 2;
const int val_per_iter = blockDim.x - 32;
T local_A[4];
T local_B[128];
const int a_tile_offset = 16;
const int b_tile_offset = (16 * 32 + 16);
__shared__ T smem_A[8 * 16 + (2 * 16 * (batch_size_warps - 1))];
__shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))];
//__shared__ T smem_C[8*32];
wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
......@@ -2298,194 +2152,177 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if (idx < K && warp_id < (WARPS - 1)) {
if (loaded_values == 0) {
local_A[0] = A[idx];
local_A[1] = A[idx + (1 * val_per_iter)];
local_A[2] = A[idx + (2 * val_per_iter)];
local_A[3] = A[idx + (3 * val_per_iter)];
#pragma unroll 32
for (int col = 0; col < 32; col++) {
local_B[col] = B[(col_offset + col) * ldb + idx];
local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)];
local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)];
local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)];
}
loaded_values = 3;
} else {
if (loaded_values == 3) {
local_A[0] = local_A[1];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (32)];
} else if (loaded_values == 2) {
local_A[0] = local_A[2];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (64)];
} else {
local_A[0] = local_A[3];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (96)];
}
loaded_values--;
}
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];
#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
local_B[col];
} else if (warp_id < (WARPS - 1)) {
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = 0.0f;
#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
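// ticktock double buffering (sketch): smem_A/smem_B hold two tile sets that
// swap roles every iteration. Warps 0..WARPS-2 fill one set from global
// memory while the last warp runs wmma on the set filled in the previous
// iteration, overlapping loads with the tensor-core math.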
// for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) {
idx = base_idx + threadIdx.x;
__syncthreads();
if (idx < K && warp_id < (WARPS - 1)) {
// local_A[0] = A[idx];
// #pragma unroll 32
// for(int col = 0; col < 32; col++)
// local_B[col] = B[(col_offset+col)*ldb+idx];
if (loaded_values == 0) {
local_A[0] = A[idx];
local_A[1] = A[idx + (1 * val_per_iter)];
local_A[2] = A[idx + (2 * val_per_iter)];
local_A[3] = A[idx + (3 * val_per_iter)];
#pragma unroll 32
for (int col = 0; col < 32; col++) {
local_B[col] = B[(col_offset + col) * ldb + idx];
local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)];
local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)];
local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)];
}
loaded_values = 3;
} else {
if (loaded_values == 3) {
local_A[0] = local_A[1];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (32)];
} else if (loaded_values == 2) {
local_A[0] = local_A[2];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (64)];
} else {
local_A[0] = local_A[3];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (96)];
}
loaded_values--;
}
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];
#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
local_B[col];
} else if (warp_id < (WARPS - 1)) {
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = 0.0f;
#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
if (warp_id == (WARPS - 1))
for (int k = 0; k < batch_size_warps; k++) {
wmma::load_matrix_sync(
a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16
); // 111 mu
wmma::load_matrix_sync(
b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16
); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}
__syncthreads();
if (warp_id != (WARPS - 1)) {
return;
}
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;
ticktock = ticktock == 0 ? 1 : 0;
for (int k = 0; k < batch_size_warps; k++) {
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
// 129 mu
if (warp_id == (WARPS - 1))
wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
if (col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
#endif
}
template <typename T> __device__ void printnonzero(T* A, int num_values, const char* strval) {
for (int i = 0; i < num_values; i++)
if ((float)A[i] != 0.0)
printf("%s %i %f\n", strval, i, (float)A[i]);
}
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
) {
//// element-wise kernel
//// 1. Load batch x k into registers
......@@ -2507,17 +2344,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
//// 9. write outputs to matmul output matrix
#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x * 32;
const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS - 1) * 2;
T quant_map[16];
#pragma unroll 16
for (int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i];
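// quant_map caches the 16 NF4 codebook values (nf4_data) per thread; each
// byte of B packs two 4-bit codes, unpacked below as (byte >> 4) and
// (byte & 0x0F) before rescaling.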
//__shared__ T quant_map[16*160];
......@@ -2525,20 +2362,19 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
T local_B[64];
unsigned char local_B_4bit[32];
const int a_tile_offset = 16;
const int b_tile_offset = (16 * 32 + 16);
__shared__ T smem_A[8 * 16 + (16 * (batch_size_warps - 1))];
__shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))];
__shared__ T smem_C[8 * 32];
wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
for (int i = threadIdx.x; i < (8 * 32); i += blockDim.x)
smem_C[i] = 0.0f;
__syncthreads();
......@@ -2547,305 +2383,287 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if (idx < K && warp_id < (WARPS - 1)) {
if (loaded_values == 0) {
local_A[0] = A[idx];
local_A[1] = A[idx + blockDim.x - 32];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B_4bit[col] = B[(col_offset + col) * ldb + idx];
loaded_values = 1;
} else {
local_A[0] = local_A[1];
loaded_values--;
#pragma unroll 64
for (int col = 0; col < 64; col += 2) {
// local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
// local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
// local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
// local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
// local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
// local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
// local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
// local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
// quant_map has only 16 register entries, so index by the nibble directly
local_B[col] = quant_map[local_B_4bit[col / 2] >> 4] * T(17.0);
local_B[col + 1] = quant_map[local_B_4bit[col / 2] & 0x0F] * T(17.0);
}
}
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];
#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
local_B[col];
} else if (warp_id < (WARPS - 1)) {
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = 0.0f;
#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
// if(threadIdx.x == 0)
// printf("aa %i %i\n", idx, loaded_values);
// for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) {
idx = base_idx + threadIdx.x;
// if(threadIdx.x == 0)
// printf("%i %i\n", idx, loaded_values);
//__syncthreads();
if (idx < K && warp_id < (WARPS - 1)) {
if (loaded_values == 0) {
local_A[0] = A[idx];
local_A[1] = A[idx + blockDim.x - 32];
#pragma unroll 32
for (int col = 0; col < 32; col++) {
local_B_4bit[col] = B[(col_offset + col) * ldb + idx];
local_B_4bit[col + 16] = B[(col_offset + col) * ldb + idx];
}
loaded_values = 1;
} else {
local_A[0] = local_A[1];
loaded_values--;
int absidx = (idx + col_offset) / blocksize;
half local_absmax = __ldg(&(absmax[absidx]));
#pragma unroll 64
for (int col = 0; col < 64; col += 2) {
// local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
// local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
// local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
// local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
// local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
// local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
local_B[col] = quant_map[(local_B_4bit[col / 2] >> 4)] * T(absidx);
local_B[col + 1] = quant_map[(local_B_4bit[col / 2] & 0x0F)] * T(absidx);
}
//printnonzero<T>(local_B, 128, "");
// printnonzero<T>(local_B, 128, "");
}
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];
#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
local_B[col];
} else if (warp_id < (WARPS - 1)) {
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = 0.0f;
#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
if (warp_id == (WARPS - 1))
for (int k = 0; k < batch_size_warps; k++) {
wmma::load_matrix_sync(
a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16
); // 111 mu
wmma::load_matrix_sync(
b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16
); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}
__syncthreads();
// if(threadIdx.x == 0)
//{
// printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
// printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
// }
if (warp_id != (WARPS - 1)) {
return;
}
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;
ticktock = ticktock == 0 ? 1 : 0;
for (int k = 0; k < batch_size_warps; k++) {
// if(warp_lane == 0)
// printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
// 129 mu
if (warp_id == (WARPS - 1))
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
//printnonzero<T>(smem_C, 32, "");
// printnonzero<T>(smem_C, 32, "");
if (col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_C[warp_lane];
#endif
}
#define num_values_4bit 32
template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
int lda, int ldb, int ldc, int blocksize
) {
// per threadblock:
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1x32 * 32x4 -> 1x4 outputs per thread block
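// Put differently (a sketch of the mapping implemented below): each warp owns
// one output row row_B; its 32 lanes stride across K in chunks of
// num_values_4bit values, and the per-lane partial dot products are
// warp-reduced into out[row_B] at the end.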
typedef cub::WarpReduce<float> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32];
const int warp_idx = threadIdx.x / 32;
const int warp_lane = threadIdx.x % 32;
const int row_B = (THREADS / 32) * blockIdx.x + warp_idx;
const int offset_B = ldb * row_B;
const int num_values_8bit = num_values_4bit / 2;
float local_C = 0.0f;
unsigned char local_B_4bit[num_values_8bit];
T local_B[num_values_4bit / 4];
T local_A[num_values_4bit / 4];
__shared__ T quant_map[16];
T local_absmax = T(0.0f);
if (threadIdx.x < 16)
quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
// for(int i = threadIdx.x; i < 16; i++)
// quant_map[i] = T(__ldg(&datatype[i]));
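// The barrier below ensures the fully populated quant_map LUT is visible to
// every thread in the block before dequantization reads it.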
__syncthreads();
// A: [1, K]
// B: [N, K]
for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
const int inner_idx_halved = inner_idx / 2;
// Since blocksize will always be a power-of-2, we avoid more expensive
// division by the blocksize and instead use a shift operation.
// This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize.
const int absidx = ((2 * offset_B) + inner_idx) >> (31 - __clz(blocksize));
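// Worked example (illustrative): for blocksize = 64, __clz(64) = 25, so the
// shift amount is 31 - 25 = 6 and x >> 6 == x / 64 for non-negative x,
// e.g. 300 >> 6 = 4 == 300 / 64.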
local_absmax = __ldg(&(absmax[absidx]));
if (row_B < M) {
if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
// this vectorized load is the most important operation for performance
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] =
reinterpret_cast<int4*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
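// A single int4 (128-bit) load fetches num_values_8bit = 16 bytes, i.e. all
// 32 packed 4-bit values this lane consumes in the current iteration.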
} else {
#pragma unroll
for (int j = 0; j < (num_values_8bit); j++)
if ((inner_idx_halved) + j < (K / 2))
local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
else
local_B_4bit[j] = 0b01110111;
}
} else {
#pragma unroll
for (int j = 0; j < (num_values_8bit); j++)
local_B_4bit[j] = 0b01110111;
}
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int k = 0; k < num_values_8bit / 4; k++) {
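// Each packed byte holds two 4-bit codes: the high nibble (>> 4) and the
// low nibble (& 0x0F). Each code indexes the 16-entry quant_map LUT and is
// scaled by the block's absmax to reconstruct the weight.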
#if BNB_BF16_AVAILABLE
local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
#else
// bf16 multiplication not supported
local_B[k * 2] =
T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * (float)local_absmax);
local_B[k * 2 + 1] =
T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * (float)local_absmax);
#endif
}
if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
// this vectorized load of A is also relatively important for performance
if (BITS == 16) {
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 4) + i];
} else {
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] =
reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
}
} else
#pragma unroll
for (int k = 0; k < num_values_4bit / 4; k++)
if (inner_idx + (i * num_values_4bit / 4) + k < K)
local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
else
local_A[k] = T(0.0f);
// accumulate in float; small performance hit for Ampere, but lower error for outputs
#pragma unroll
for (int k = 0; k < num_values_4bit / 4; k++) {
#if BNB_BF16_AVAILABLE
local_C += (float)(local_A[k] * local_B[k]);
#else
// bf16 multiplication not supported
local_C += ((float)local_A[k] * (float)local_B[k]);
#endif
}
}
}
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
if (row_B < M && warp_lane == 0)
out[row_B] = T(local_C);
}
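// Illustrative launch shape (an assumption, not taken from this file): with
// THREADS = 128, each block covers THREADS / 32 = 4 rows of B, so a host-side
// caller might launch, e.g.:
//   kgemm_4bit_inference_naive<half, 128, 16>
//       <<<(M + 3) / 4, 128>>>(M, N, K, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);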
template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n) {
for (long i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += (blockDim.x * gridDim.x)) {
switch (FUNC) {
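// FUNC is a non-type template parameter, so this switch is resolved at
// compile time and the untaken branches can be eliminated by the compiler.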
case FILL:
A[i] = (T)value;
break;
......@@ -2853,70 +2671,143 @@ template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long
A[i] = (T)i;
break;
case _MUL:
A[i] = A[i] * B[i];
break;
}
}
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template __global__ void kfunc<float, FILL>(float* A, float* B, float value, long n);
template __global__ void kfunc<unsigned char, FILL>(unsigned char* A, unsigned char* B, unsigned char value, long n);
template __global__ void kfunc<float, ARANGE>(float* A, float* B, float value, long n);
template __global__ void kfunc<float, _MUL>(float* A, float* B, float value, long n);
// these are not used and make no sense, but the compiler needs them
// template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 256>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 192>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 160>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 128>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
// template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 64>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 96>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
// these are not used and make no sense, but the compiler needs them
// template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 256>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 192>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 160>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 128>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
// template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 32>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 64>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 96>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void kgemm_4bit_inference<half, 96>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 128>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 160>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 256>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out,
int lda, int ldb, int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(
int M, int N, int K, __nv_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
__nv_bfloat16* out, int lda, int ldb, int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(
int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
float* out, int lda, int ldb, int ldc, int blocksize
);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kdequant_mm_int32_fp16<4, 512>(
int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>( \
gtype * g, gtype * p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, \
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n \
);
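// For example, MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) below
// expands to the explicit instantiation
// kPreconditionOptimizer32bit1State<half, MOMENTUM, 4096, 8>(...).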
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
......@@ -2932,8 +2823,11 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>( \
gtype * g, gtype * p, float* state1, float* unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, const int step, \
const float lr, const float gnorm_scale, const bool skip_zeros, const int n \
);
MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
......@@ -2949,10 +2843,11 @@ MAKE_Optimizer32bit1State(ADAGRAD, float)
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>( \
gtype * g, gtype * p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, \
const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, \
const int n \
);
MAKE_PreconditionOptimizer32bit2State(ADAM, float)
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
......@@ -2961,31 +2856,49 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)
template __global__ void kOptimizer32bit2State<float, ADAM>(
float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template __global__ void kOptimizer32bit2State<half, ADAM>(
half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(
__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(
float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(
half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(
__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>( \
gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, \
const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, \
float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n \
);
MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
......@@ -2997,16 +2910,12 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>( \
gtype * p, gtype* const g, unsigned char* state1, const float* unorm, const float max_unorm, \
const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, \
const float gnorm_scale, const int n \
);
MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
......@@ -3017,39 +2926,39 @@ MAKE_optimizerStatic8bit1State(LION, float)
MAKE_optimizerStatic8bit1State(ADAGRAD, half)
MAKE_optimizerStatic8bit1State(ADAGRAD, float)
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>( \
gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, \
unsigned char* __restrict__ const state2, float* unorm, const float beta1, const float beta2, const float eps, \
const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, \
float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n \
);
MAKE_PreconditionStatic8bit2State(ADAM, half)
MAKE_PreconditionStatic8bit2State(ADAM, float)
#define MAKE_optimizerStatic8bit2State(oname, gtype) \
template __global__ void kOptimizerStatic8bit2State<gtype, oname>( \
gtype * p, gtype* const g, unsigned char* state1, unsigned char* state2, const float* unorm, \
const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, \
const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, \
const int n \
);
MAKE_optimizerStatic8bit2State(ADAM, half)
MAKE_optimizerStatic8bit2State(ADAM, float)
template __global__ void
kPercentileClipping<float, 2048, 4>(float* __restrict__ g, float* gnorm_vec, int step, const int n);
template __global__ void
kPercentileClipping<half, 2048, 4>(half* __restrict__ g, float* gnorm_vec, int step, const int n);
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>( \
float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \
const int rand_offset, const int n \
);
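// E.g. MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) below expands to
// an explicit instantiation of kQuantizeBlockwise<half, 4096, 4, 0, General8bit>,
// which keeps these template kernels linkable from the host-side library code.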
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit)
......@@ -3119,24 +3028,41 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(
float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(
float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n
);
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>( \
gtype * p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, \
const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, \
float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \
);
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1)
......@@ -3146,14 +3072,11 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
gtype * p, gtype* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, \
const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, \
float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \
);
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1)
......@@ -3168,5 +3091,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)
template __device__ void printnonzero<float>(float* A, int num_values, const char* strval);
template __device__ void printnonzero<half>(half* A, int num_values, const char* strval);
......@@ -9,116 +9,129 @@
#ifndef kernels
#define kernels
__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n);
__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n);
template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
__global__ void kQuantizeBlockwise(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
);
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(
T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(
T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(
T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(
T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1,
const float beta2, const float eps, const float weight_decay, const int step, const float lr,
const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit1State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1,
const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1,
float* new_max1, const float weight_decay, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit1State(
T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale,
const int n
);
template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit2State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2,
float* unorm, const float beta1, const float beta2, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit2State(
T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit2StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2,
const float beta3, const float alpha, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps,
const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n);
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16(
int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);
template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
);
template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
int lda, int ldb, int ldc, int blocksize
);
template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n);
#endif
......@@ -5,37 +5,34 @@
#define NUM 4
#define NUM_BLOCK 4096
static inline MPSGraph* get_graph()
{
static inline MPSGraph* get_graph() {
static MPSGraph* cur = nil;
if(!cur) {
if (!cur) {
cur = [[MPSGraph alloc] init];
}
return cur;
}
static inline id<MTLDevice> get_device()
{
NSError *error = nil;
static inline id<MTLDevice> get_device() {
NSError* error = nil;
static id<MTLDevice> device = nil;
if(!device) {
if (!device) {
device = MTLCreateSystemDefaultDevice();
}
if(!device) {
if (!device) {
NSLog(@"Failed to get MPS device");
abort();
}
return device;
}
static inline id<MTLLibrary> get_library()
{
NSError *error = nil;
static inline id<MTLLibrary> get_library() {
NSError* error = nil;
static id<MTLLibrary> library = nil;
if(!library) {
if (!library) {
library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error];
}
if(!library) {
if (!library) {
NSLog(@"Failed to load bitsandbytes.metallib");
abort();
}
......@@ -44,20 +41,18 @@ static inline id<MTLLibrary> get_library()
/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n)
{
id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"];
return out;
id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0
dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out;
}*/
// MPSGraph function for quantize
extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n)
{
extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) {
id<MTLDevice> device = get_device();
id<MTLLibrary> library = get_library();
static id<MTLFunction> kernel = nil;
if(!kernel) {
if (!kernel) {
kernel = [library newFunctionWithName:@"quantize"];
if(!kernel) {
if (!kernel) {
NSLog(@"Failed to load bitsandbytes.metallib");
abort();
}
......
......@@ -3,170 +3,190 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <ops.cuh>
#include <kernels.cuh>
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
#include <cassert>
#include <common.h>
#include <cub/device/device_scan.cuh>
#include <kernels.cuh>
#include <limits>
#include <ops.cuh>
#define ERR_NOT_IMPLEMENTED 100
using namespace BinSearch;
using std::cout;
using std::endl;
void quantize(float *code, float *A, unsigned char *out, int n)
{
int num_blocks = n/1024;
void quantize(float* code, float* A, unsigned char* out, int n) {
int num_blocks = n / 1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream)
{
int num_blocks = n/1024;
void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
int num_blocks = n / 1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kDequantize<<<num_blocks, 1024, 0, stream>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
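Both launchers above size their grids with the same divide-then-round-up arithmetic; a minimal standalone sketch of the equivalent ceil-division idiom (hypothetical kernel and function names):

#include <cuda_runtime.h>

__global__ void kScale(float* A, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) // the last block may extend past n, so guard
        A[i] *= s;
}

void scale(float* A, float s, int n, cudaStream_t stream) {
    const int threads = 1024;
    // same value as: num_blocks = n / 1024, plus one if n % 1024 != 0
    int num_blocks = (n + threads - 1) / threads;
    kScale<<<num_blocks, threads, 0, stream>>>(A, s, n);
}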
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{
int num_blocks = n/blocksize;
template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(
float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
) {
int num_blocks = n / blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
if(blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 2048)
if (blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE>
<<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if (blocksize == 2048)
kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024)
else if (blocksize == 1024)
kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512)
else if (blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 256)
else if (blocksize == 256)
kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128)
else if (blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 64)
else if (blocksize == 64)
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream)
{
template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, cudaStream_t stream
) {
// printf("stream==%d\n",stream);
int num_blocks = n/blocksize;
int num_blocks = n / blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
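// FP4/NF4 (DATA_TYPE > 0) pack two 4-bit values per byte, which is why the kernel below receives blocksize / 2.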
if(DATA_TYPE > 0)
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize/2, n);
if (DATA_TYPE > 0)
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
<<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
else
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
<<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
int num_blocks = n/4096;
template <typename T, int OPTIMIZER>
void optimizer32bit(
T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, const float beta1,
const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step,
const float lr, const float gnorm_scale, bool skip_zeros, const int n
) {
int num_blocks = n / 4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
switch(OPTIMIZER)
{
switch (OPTIMIZER) {
case ADAM:
case ADEMAMIX:
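// max_unorm > 0 enables update clipping: unorm is zeroed and a precondition kernel accumulates the update norm before the main optimizer kernel runs.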
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
if (max_unorm > 0.0f) {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(
g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr,
gnorm_scale, skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
if (max_unorm > 0.0f) {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>
<<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale,
skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale,
skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
if (max_unorm > 0.0f) {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>
<<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
break;
}
}
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm,
float beta1, float beta2,
float eps, int step, float lr,
float* quantiles1, float* quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, int n)
{
int num_blocks = n/4096;
template <typename T, int OPTIMIZER>
void optimizerStatic8bit(
T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
) {
int num_blocks = n / 4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
if (max_unorm > 0.0f) {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
}
switch(OPTIMIZER)
{
switch (OPTIMIZER) {
case ADAM:
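// new_max1/new_max2 are zeroed so the precondition kernel can recompute the quantization maxima used to requantize the states.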
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(
p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1,
new_max2, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(
p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2,
max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(
p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1,
weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1,
weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(
p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
default:
......@@ -179,39 +199,23 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
#define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
T* p,
T* g,
unsigned char* state1,
unsigned char* state2,
float beta1,
float beta2,
float beta3,
float alpha,
float eps,
int step,
float lr,
float* quantiles1,
float* quantiles2,
float* absmax1,
float* absmax2,
float weight_decay,
const float gnorm_scale,
bool skip_zeros,
int n
template <typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(
T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, bool skip_zeros, int n
) {
int num_blocks = 0;
switch(OPTIMIZER)
{
switch (OPTIMIZER) {
case ADAM:
case ADEMAMIX:
num_blocks = n/BLOCKSIZE_2STATE;
num_blocks = n / BLOCKSIZE_2STATE;
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale,
skip_zeros, n
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE>
<<<num_blocks, BLOCKSIZE_2STATE / NUM_2STATE>>>(
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1,
absmax2, weight_decay, gnorm_scale, skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
......@@ -219,88 +223,76 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
case RMSPROP:
case ADAGRAD:
case LION:
num_blocks = n/BLOCKSIZE_1STATE;
num_blocks = n / BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE>
<<<num_blocks, BLOCKSIZE_1STATE / NUM_1STATE>>>(
p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
}
}
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
int num_blocks = n/2048;
template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n) {
int num_blocks = n / 2048;
num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
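// gnorm_vec serves as a 100-entry ring buffer of gradient-norm history; the current step's slot is cleared before the kernel accumulates into it.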
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float)));
kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{
void gemmex(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc
) {
const int falpha = 1;
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
const void* alpha = &falpha;
const void* beta = &fbeta;
cublasStatus_t status;
status = cublasGemmEx(context->m_handle,
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
m, n, k,
alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
C, CUDA_R_32I, ldc,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
status = cublasGemmEx(
context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k,
alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, C, CUDA_R_32I, ldc, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
if (status != CUBLAS_STATUS_SUCCESS)
{
if (status != CUBLAS_STATUS_SUCCESS) {
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
}
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount)
{
void strided_gemmex(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
) {
const int falpha = 1;
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
const void* alpha = &falpha;
const void* beta = &fbeta;
cublasStatus_t status;
//cout << transposeA << transposeB << endl;
//printf("%i %i %i\n", m,n,k);
//printf("%i %i %i\n", lda,ldb,ldc);
//printf("%i %i %i\n", strideA, strideB, strideC);
//printf("%i\n", batchCount);
status = cublasGemmStridedBatchedEx(context->m_handle,
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
m, n, k,
alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
if (status != CUBLAS_STATUS_SUCCESS)
{
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
}
// cout << transposeA << transposeB << endl;
// printf("%i %i %i\n", m,n,k);
// printf("%i %i %i\n", lda,ldb,ldc);
// printf("%i %i %i\n", strideA, strideB, strideC);
// printf("%i\n", batchCount);
status = cublasGemmStridedBatchedEx(
context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k,
alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, C,
CUDA_R_32I, ldc, (long long int)strideC, batchCount, CUDA_R_32I, CUBLAS_GEMM_DEFAULT
);
if (status != CUBLAS_STATUS_SUCCESS) {
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
}
int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}
int roundoff(int v, int d) { return (v + d - 1) / d * d; }
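roundoff rounds v up to the nearest multiple of d; a small self-contained check of the arithmetic (hypothetical names, plain C++):

#include <cassert>

static int roundoff_ref(int v, int d) { return (v + d - 1) / d * d; }

int main() {
    assert(roundoff_ref(33, 32) == 64); // rounds up to the next multiple of 32
    assert(roundoff_ref(32, 32) == 32); // exact multiples are unchanged
    assert(roundoff_ref(7, 8) == 8);    // the quantity used by 32 * roundoff(dim1, 8) below
    return 0;
}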
template<int ORDER> cublasLtOrder_t get_order()
{
switch(ORDER)
{
template <int ORDER> cublasLtOrder_t get_order() {
switch (ORDER) {
case ROW:
return CUBLASLT_ORDER_ROW;
break;
......@@ -329,11 +321,8 @@ template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
switch(ORDER)
{
template <int ORDER> int get_leading_dim(int dim1, int dim2) {
switch (ORDER) {
case ROW:
return dim2;
break;
......@@ -342,14 +331,14 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
break;
case COL32:
// 32*row tiles
return dim1*32;
return dim1 * 32;
break;
case COL_TURING:
return 32*roundoff(dim1, 8);
return 32 * roundoff(dim1, 8);
break;
case COL_AMPERE:
// 32*32 tiles
return 32*roundoff(dim1, 32);
return 32 * roundoff(dim1, 32);
break;
default:
return 0;
......@@ -357,15 +346,10 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
}
}
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
cublasLtHandle_t ltHandle,
int m, int n, int k,
const int8_t * A,
const int8_t * B,
void * C,
float * row_scale,
int lda, int ldb, int ldc,
cudaStream_t stream
template <int DTYPE_OUT, int SCALE_ROWS>
int igemmlt(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
) {
// Calculate C = A^T @ B, in col-major layout.
......@@ -393,17 +377,14 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
// Default layout order is col major
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));
has_error |=
checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));
if (DTYPE_OUT == 32) {
int alpha = 1, beta = 0;
has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc,
&alpha, A, aDesc,
B, bDesc, &beta,
(int32_t*)C, cDesc,
(int32_t*)C, cDesc,
NULL, NULL, 0, stream
ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, NULL, NULL,
0, stream
));
} else {
// This path is unlikely to be used, as 8-bit accumulation is prone to overflow.
......@@ -411,29 +392,18 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
if (!SCALE_ROWS) {
float alpha = 1.0f, beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc,
&alpha, A, aDesc,
B, bDesc, &beta,
(int8_t*)C, cDesc,
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL,
NULL, 0, stream
));
} else {
cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
float beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(
matmulDesc,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointerMode,
sizeof(alphaVec)
matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, sizeof(alphaVec)
));
has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc,
row_scale, A, aDesc,
B, bDesc, &beta,
(int8_t*)C, cDesc,
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
ltHandle, matmulDesc, row_scale, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL,
NULL, 0, stream
));
}
}
......@@ -443,30 +413,33 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc));
has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
if(has_error == 1)
if (has_error == 1)
printf("error detected");
return has_error;
}
int fill_up_to_nearest_multiple(int value, int multiple)
{
int fill_up_to_nearest_multiple(int value, int multiple) {
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, cudaStream_t stream)
{
void dequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
) {
const int threads = 512;
const int num_per_thread = 4;
const int num_per_block = threads * num_per_thread;
const int n = numRows*numCols;
const int n = numRows * numCols;
const int num_blocks = (n + num_per_block - 1) / num_per_block;
kdequant_mm_int32_fp16<num_per_thread, threads><<<num_blocks, threads, 0, stream>>>(A, rowStats, colStats, out, bias, numRows, numCols, n);
kdequant_mm_int32_fp16<num_per_thread, threads>
<<<num_blocks, threads, 0, stream>>>(A, rowStats, colStats, out, bias, numRows, numCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
) {
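// threshold == 0 takes the dense path; a nonzero threshold dispatches the SPARSE_DECOMP variant, which handles outliers above the threshold separately.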
if (threshold == 0.0) {
kInt8VectorQuant<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
} else {
......@@ -475,7 +448,7 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
if (threshold == 0.0)
kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
else
......@@ -483,94 +456,101 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
void spmm_coo(
cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
) {
cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC;
float alpha = 1.0f;
float beta = 0.0f;
void *dBuffer = NULL;
void* dBuffer = NULL;
size_t bufferSize = 0;
CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
A_rowidx, A_colidx, A_vals,
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
CHECK_CUSPARSE(cusparseCreateCoo(
&descA, A_rows, A_cols, A_nnz, A_rowidx, A_colidx, A_vals, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
CUDA_R_16F
));
// Create dense matrix C
CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
CHECK_CUSPARSE(cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, CUDA_R_16F, CUSPARSE_ORDER_ROW));
// Create dense matrix B
if(transposed_B)
{
if (transposed_B) {
int tmp = A_cols;
A_cols = B_cols;
B_cols = tmp;
}
CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
CHECK_CUSPARSE(cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, CUDA_R_16F, CUSPARSE_ORDER_ROW));
// allocate an external buffer if needed
CHECK_CUSPARSE( cusparseSpMM_bufferSize(
handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
CHECK_CUSPARSE(cusparseSpMM_bufferSize(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta,
descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize
));
CUDA_CHECK_RETURN(cudaMalloc(&dBuffer, bufferSize));
// execute SpMM
CHECK_CUSPARSE( cusparseSpMM(handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));
CHECK_CUSPARSE(cusparseSpMM(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta,
descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, dBuffer
));
// destroy matrix/vector descriptors
CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
CHECK_CUSPARSE(cusparseDestroySpMat(descA));
CHECK_CUSPARSE(cusparseDestroyDnMat(descB));
CHECK_CUSPARSE(cusparseDestroyDnMat(descC));
CUDA_CHECK_RETURN(cudaFree(dBuffer));
}
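A minimal host-side usage sketch for spmm_coo, assuming ops.cuh from this repo, a live CUDA device, and the hypothetical demo names below; buffer fill and cleanup elided:

#include <ops.cuh>

void spmm_coo_demo() {
    // 2x3 sparse A with 2 nonzeros, times row-major 3x4 dense B -> 2x4 dense C
    int *rowidx, *colidx;
    half *vals, *B, *C;
    CUDA_CHECK_RETURN(cudaMalloc((void**)&rowidx, 2 * sizeof(int)));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&colidx, 2 * sizeof(int)));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&vals, 2 * sizeof(half)));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&B, 3 * 4 * sizeof(half)));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&C, 2 * 4 * sizeof(half)));
    // ... populate rowidx/colidx/vals/B on the device ...
    ContextCusparse ctx; // owns the cusparseHandle_t
    // ldb/ldc are row-major leading dimensions, i.e. the column counts of B and C
    spmm_coo(ctx.m_handle, rowidx, colidx, vals, /*A_nnz=*/2, /*A_rows=*/2, /*A_cols=*/3,
             /*B_cols=*/4, /*ldb=*/4, B, /*ldc=*/4, C, /*transposed_B=*/false);
}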
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
template <typename T, int BITS>
void spmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits) {
int num_blocks = (m+31)/32;
int num_blocks = (m + 31) / 32;
if(bits == 32)
gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16)
gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
if (bits == 32)
gemm_device<T, 32, 32><<<num_blocks, 32, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
if (bits == 16)
gemm_device<T, 16, 160><<<num_blocks, 160, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
}
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
template <typename T>
void gemm_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
) {
int num_blocks = (m+31)/32;
int num_blocks = (m + 31) / 32;
kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
kgemm_4bit_inference<T, 96><<<num_blocks, 96, 0, 0>>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{
template <typename T, int BITS>
void gemm_4bit_inference_naive(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
int blocksize, cudaStream_t stream
) {
int num_blocks = (m+3)/4;
kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
int num_blocks = (m + 3) / 4;
kgemm_4bit_inference_naive<T, 128, BITS>
<<<num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
{
template <typename T, int FUNC> void func(T* A, T* B, T value, long n) {
int threads = 512;
int blocks = n/threads;
int blocks = n / threads;
blocks = n % threads == 0 ? blocks : blocks + 1;
blocks = blocks > 65535 ? 65535 : blocks;
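// the grid is clamped to 65535 blocks, so kfunc must cover any elements beyond blocks * threads with a stride loop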
kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n);
......@@ -581,103 +561,154 @@ template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
// TEMPLATE DEFINITIONS
//==============================================================
template void func<float, FILL>(float *A, float *B, float value, long n);
template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
template void func<float, FILL>(float* A, float* B, float value, long n);
template void func<unsigned char, FILL>(unsigned char* A, unsigned char* B, unsigned char value, long n);
template void func<float, ARANGE>(float* A, float* B, float value, long n);
template void func<float, _MUL>(float* A, float* B, float value, long n);
template void gemm_4bit_inference<half>(
int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
);
template void gemm_4bit_inference_naive<half, 16>(
int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
);
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(
int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
);
template void gemm_4bit_inference_naive<float, 32>(
int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
);
// template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc,
// int bits);
template void gemm_host<half>(int m, int n, int k, half* A, half* B, half* out, int lda, int ldb, int ldc, int bits);
template void spmm_coo_very_sparse_naive<half, 16>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);
template void spmm_coo_very_sparse_naive<signed char, 8>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);
template int igemmlt<32, 0>(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
);
template int igemmlt<8, 0>(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
);
template int igemmlt<8, 1>(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
);
template void quantizeBlockwise<half, 1, General8bit>(
float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, General8bit>(
float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, FP4>(
float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, NF4>(
float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 1, General8bit>(
float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, General8bit>(
float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, FP4>(
float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, NF4>(
float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
const int n
);
template void dequantizeBlockwise<float, General8bit>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<float, FP4>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<float, NF4>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, General8bit>(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, FP4>(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, NF4>(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, FP4>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, NF4>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, float)
template void optimizer32bit<gtype, name>( \
gtype * g, gtype * p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, \
const int n \
);
MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, float)
#define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, \
const float gnorm_scale, int n); \
MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float)
MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)
MAKE_optimizerStatic8bit(ADAGRAD, half)
MAKE_optimizerStatic8bit(ADAGRAD, float)
template void optimizerStatic8bit<gtype, name>( \
gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
const float gnorm_scale, int n \
);
MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float)
MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)
MAKE_optimizerStatic8bit(ADAGRAD, half)
MAKE_optimizerStatic8bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
template void optimizerStatic8bitBlockwise<gtype, optim_name>( \
gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
);
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
......@@ -696,8 +727,8 @@ MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(float* g, float* gnorm_vec, int step, const int n);
template void percentileClipping(half* g, float* gnorm_vec, int step, const int n);
template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
......
......@@ -3,41 +3,41 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#ifndef ops_H
#define ops_H
#include <assert.h>
#include <cstdint>
#include <stdio.h>
#include <iostream>
#include <assert.h>
#include <stdio.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <cusparse.h>
#include <vector>
#include <functional>
#include <vector>
#define CUDA_CHECK_RETURN(value) { \
#define CUDA_CHECK_RETURN(value) \
{ \
cudaError_t _m_cudaStat = value; \
if (_m_cudaStat != cudaSuccess) { \
fprintf(stderr, "Error %s at line %d in file %s\n", \
cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
fprintf(stderr, "Error %s at line %d in file %s\n", cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
exit(1); \
} }
} \
}
#define CHECK_CUSPARSE(value) { \
#define CHECK_CUSPARSE(value) \
{ \
cusparseStatus_t _m_cudaStat = value; \
if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
fprintf(stderr, "Error %s at line %d in file %s\n", \
cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
fprintf( \
stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__ \
); \
exit(1); \
} }
} \
}
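// Usage: both macros wrap a call and exit the process on failure, e.g.
//   CUDA_CHECK_RETURN(cudaMalloc(&dBuffer, bufferSize));
//   CHECK_CUSPARSE(cusparseCreate(&handle));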
inline void checkCudaStatus(cudaError_t status) {
if (status != cudaSuccess) {
......@@ -49,19 +49,17 @@ inline void checkCudaStatus(cudaError_t status) {
inline int checkCublasStatus(cublasStatus_t status) {
if (status != CUBLAS_STATUS_SUCCESS) {
printf("cuBLAS API failed with status %d\n", status);
//throw std::logic_error("cuBLAS API failed");
// throw std::logic_error("cuBLAS API failed");
return 1;
}
return 0;
}
typedef enum Operations_t
{
typedef enum Operations_t {
ksmul = 0,
} Operations_t;
typedef enum Optimizer_t
{
typedef enum Optimizer_t {
ADAM = 0,
MOMENTUM = 1,
RMSPROP = 2,
......@@ -71,8 +69,7 @@ typedef enum Optimizer_t
ADEMAMIX = 6
} Optimizer_t;
typedef enum Transform_t
{
typedef enum Transform_t {
ROW = 0,
COL = 1,
COL32 = 2,
......@@ -80,109 +77,135 @@ typedef enum Transform_t
COL_AMPERE = 4,
} Transform_t;
typedef enum DataType_t
{
typedef enum DataType_t {
General8bit = 0,
FP4 = 1,
NF4 = 2,
} DataType_t;
typedef enum Funcs_t
{
typedef enum Funcs_t {
FILL = 0,
ARANGE = 1,
_MUL = 2,
} Funcs_t;
class Context
{
class Context {
public:
cublasHandle_t m_handle;
Context()
{
Context() {
cublasHandle_t handle;
cublasCreate_v2(&handle);
m_handle = handle;
}
};
class ContextLt
{
class ContextLt {
public:
cublasLtHandle_t m_handle;
ContextLt()
{
ContextLt() {
cublasLtHandle_t handle;
cublasLtCreate(&handle);
m_handle = handle;
}
};
class ContextCusparse
{
class ContextCusparse {
public:
cusparseHandle_t m_handle;
ContextCusparse()
{
ContextCusparse() {
cusparseHandle_t handle;
cusparseCreate(&handle);
m_handle = handle;
}
};
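The contexts are thin handle owners; a minimal sketch (hypothetical function name, in a separate file that includes ops.cuh) pairing one with the gemmex launcher declared below:

// Assumes dA (m x k) and dB (k x n) hold int8 data on the device and dC holds int32,
// with shapes satisfying cuBLAS int8 alignment requirements.
void igemm_demo(int m, int n, int k, void* dA, void* dB, void* dC) {
    Context ctx; // creates the cublasHandle_t that gemmex uses
    // column-major convention, no transposes: lda = m, ldb = k, ldc = m
    gemmex(&ctx, false, false, m, n, k, dA, dB, dC, m, k, m);
}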
void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream);
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream);
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay,
int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm,
float beta1, float beta2,
float eps, int step, float lr,
float* quantiles1, float* quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, int n);
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
bool skip_zeros, int n);
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount);
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
void quantize(float* code, float* A, unsigned char* out, int n);
void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream);
template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(
float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, cudaStream_t stream
);
template <typename T, int OPTIMIZER>
void optimizer32bit(
T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2,
float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale,
bool skip_zeros, int n
);
template <typename T, int OPTIMIZER>
void optimizerStatic8bit(
T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
);
template <typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(
T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, bool skip_zeros, int n
);
template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n);
void gemmex(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc
);
void strided_gemmex(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
);
template <int DTYPE_OUT, int SCALE_ROWS>
int igemmlt(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
);
void cutlass_igemm(
bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc
);
void dequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
);
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
);
void spmm_coo(
cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
);
template <typename T, int BITS>
void spmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);
void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);
template <typename T>
void gemm_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
);
template <typename T, int BITS>
void gemm_4bit_inference_naive(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
int blocksize, cudaStream_t stream
);
template <typename T, int FUNC> void func(T* A, T* B, T value, long n);
#endif
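// Call-shape sketch for the blockwise templates declared above; a
// non-authoritative example where device pointers, sizes, the stream, and the
// blocksize of 64 are illustrative assumptions. The NF4/FP4 paths take NULL
// for `code`, matching the exported quantizeBlockwise_*_nf4 wrappers below.
//   quantizeBlockwise<half, 0, NF4>(NULL, d_A, d_absmax, d_q, NULL, 0, 64, n);
//   dequantizeBlockwise<half, NF4>(NULL, d_q, d_absmax, d_out, 64, n, stream);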
......@@ -20,39 +20,60 @@
#if BUILD_CUDA
//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
// void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
void gemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16);
}
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
void gemm_4bit_inference(
int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
) {
gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
void gemm_4bit_inference_naive_fp16(
int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
void gemm_4bit_inference_naive_bf16(
int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive<__nv_bfloat16, 16>(
m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream
);
}
void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
void gemm_4bit_inference_naive_fp32(
int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
void fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { func<ctype, FUNC>(A, B, value, n); }
MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const float weight_decay, \
const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
void fname##32bit_grad_##gbits( \
gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
const float weight_decay, const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n \
) { \
optimizer32bit<gtype, oname>( \
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \
lr, gnorm_scale, skip_zeros, n \
); \
}
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
......@@ -70,19 +91,18 @@ MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32)
MAKE_FUNC32(ademamix, ADEMAMIX, half, fp16)
MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
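// For readability, this is what one invocation above expands to (whitespace
// simplified):
// MAKE_FUNC32(momentum, MOMENTUM, float, 32) ==>
// void momentum32bit_grad_32(
//     float* g, float* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm,
//     const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
//     const float weight_decay, const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n
// ) {
//     optimizer32bit<float, MOMENTUM>(
//         g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay,
//         step, lr, gnorm_scale, skip_zeros, n
//     );
// }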
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, float gnorm_scale, int n) \
{ \
optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
} \
void fname##_static_8bit_grad_##gbits( \
gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
float gnorm_scale, int n \
) { \
optimizerStatic8bit<gtype, oname>( \
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \
max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \
); \
}
MAKE_FUNC8(adam, ADAM, float, 32)
MAKE_FUNC8(adam, ADAM, half, 16)
......@@ -94,10 +114,16 @@ MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
void fname##_8bit_blockwise_grad_##gbits( \
gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
) { \
optimizerStatic8bitBlockwise<gtype, optim_name>( \
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \
weight_decay, gnorm_scale, skip_zeros, n \
); \
}
MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
......@@ -118,239 +144,511 @@ MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
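// Likewise, one blockwise invocation expands to (whitespace simplified):
// MAKE_BLOCKWISE8(adam, ADAM, half, fp16) ==>
// void adam_8bit_blockwise_grad_fp16(
//     half* p, half* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3,
//     float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1,
//     float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n
// ) {
//     optimizerStatic8bitBlockwise<half, ADAM>(
//         p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1,
//         absmax2, weight_decay, gnorm_scale, skip_zeros, n
//     );
// }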
void percentileClipping_g32(float* g, float* gnorm_vec, int step, const int n) {
percentileClipping<float>(g, gnorm_vec, step, n);
}
void percentileClipping_g16(half* g, float* gnorm_vec, int step, const int n) {
percentileClipping<half>(g, gnorm_vec, step, n);
}
void quantizeBlockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_bf16(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_bf16_fp4(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_bf16_nf4(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
void quantizeBlockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_fp4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_nf4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void dequantizeBlockwise_fp16(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n, stream); } \
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n, stream); } \
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n, stream); } \
void dequantizeBlockwise_fp16_fp4(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_fp16_nf4(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_fp32(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
void dequantizeBlockwise_fp32_fp4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp32_nf4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16_fp4(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16_nf4(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
int igemmlt_32(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
) {
return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
int igemmlt_8(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
) {
return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
int igemmlt_8_rowscale(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
) {
return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
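// Call-shape sketch for the LT GEMM wrappers above (illustrative; the int8
// inputs and the int32 output buffer are assumed to be device-resident, and
// the return value is assumed to be 0 on success, in line with
// checkCublasStatus):
//   int err = igemmlt_32(lt_handle, m, n, k, d_A, d_B, d_C32,
//                        /*row_scale=*/nullptr, lda, ldb, ldc, stream);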
void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void spmm_coo_very_sparse_naive_fp16(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive<half, 16>(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void spmm_coo_very_sparse_naive_int8(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive<signed char, 8>(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
#endif
extern "C"
{
extern "C" {
#if BUILD_CUDA
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream){ dequantize(code, A, out, n, stream); }
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); }
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); }
void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); }
void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); }
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
{ name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_CFUNC32(adam, float, fp32)
MAKE_CFUNC32(adam, half, fp16)
MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
MAKE_CFUNC32(momentum, float, 32)
MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, fp32)
MAKE_CFUNC32(lion, half, fp16)
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16)
MAKE_CFUNC32(ademamix, float, fp32)
MAKE_CFUNC32(ademamix, half, fp16)
MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
#define MAKE_CFUNC8(name, gtype, gbits) \
void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, float gnorm_scale, int n) \
{ \
name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
} \
MAKE_CFUNC8(adam, float, 32)
MAKE_CFUNC8(adam, half, 16)
MAKE_CFUNC8(momentum, float, 32)
MAKE_CFUNC8(momentum, half, 16)
MAKE_CFUNC8(rmsprop, float, 32)
MAKE_CFUNC8(rmsprop, half, 16)
MAKE_CFUNC8(lion, float, 32)
MAKE_CFUNC8(lion, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
{ fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{ gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long strideA, long strideB, long strideC, int batchCount)
{ strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }
Context *get_context(){ return new Context(); }
ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
dequantize(code, A, out, n, stream);
}
void cdequantize_blockwise_fp16_fp4(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp16(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp16_nf4(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cquantize_blockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32_fp4(
float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32_nf4(
float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_fp32(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32_fp4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32_nf4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cquantize_blockwise_bf16(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_bf16_fp4(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_bf16_nf4(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_bf16(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16_fp4(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16_nf4(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream);
}
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_grad_##gbits( \
gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, \
const int n \
) { \
name##32bit_grad_##gbits( \
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \
lr, gnorm_scale, skip_zeros, n \
); \
}
int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
MAKE_CFUNC32(adam, float, fp32)
MAKE_CFUNC32(adam, half, fp16)
MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
MAKE_CFUNC32(momentum, float, 32)
MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, fp32)
MAKE_CFUNC32(lion, half, fp16)
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16)
MAKE_CFUNC32(ademamix, float, fp32)
MAKE_CFUNC32(ademamix, half, fp16)
MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
#define MAKE_CFUNC8(name, gtype, gbits) \
void c##name##_static_8bit_grad_##gbits( \
gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
float gnorm_scale, int n \
) { \
name##_static_8bit_grad_##gbits( \
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \
max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \
); \
}
int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
MAKE_CFUNC8(adam, float, 32)
MAKE_CFUNC8(adam, half, 16)
MAKE_CFUNC8(momentum, float, 32)
MAKE_CFUNC8(momentum, half, 16)
MAKE_CFUNC8(rmsprop, float, 32)
MAKE_CFUNC8(rmsprop, half, 16)
MAKE_CFUNC8(lion, float, 32)
MAKE_CFUNC8(lion, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_grad_##gbits( \
gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
) { \
fname##_8bit_blockwise_grad_##gbits( \
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \
weight_decay, gnorm_scale, skip_zeros, n \
); \
}
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream)
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); }
void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
void cpercentile_clipping_g32(float* g, float* gnorm_vec, int step, const int n) {
percentileClipping_g32(g, gnorm_vec, step, n);
}
void cpercentile_clipping_g16(half* g, float* gnorm_vec, int step, const int n) {
percentileClipping_g16(g, gnorm_vec, step, n);
}
void cigemm(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc
) {
gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc);
}
void cbatched_igemm(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc, long strideA, long strideB, long strideC, int batchCount
) {
strided_gemmex(
context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount
);
}
Context* get_context() { return new Context(); }
ContextCusparse* get_cusparse() { return new ContextCusparse(); }
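// Sketch of how a C caller (or the Python ctypes layer) is expected to drive
// the exported API; the contexts are heap-allocated here and, as written,
// live for the process lifetime:
//   Context* ctx = get_context();
//   cigemm(ctx, /*transposeA=*/false, /*transposeB=*/false, m, n, k,
//          d_A, d_B, d_C, lda, ldb, ldc);
//   ContextCusparse* sp = get_cusparse(); // sp->m_handle backs cspmm_coo below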
int cigemmlt_32(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_32((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int cigemmlt_8(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_8((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int cigemmlt_8_rowscale(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_8_rowscale((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
void cdequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
) {
dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream);
}
void cget_row_stats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
getRowStats(A, rowStats, threshold, rows, cols, stream);
}
void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
}
void cint8_vector_quant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
) {
int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
}
}
void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
void cspmm_coo(
ContextCusparse* context, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
) {
spmm_coo(
(cusparseHandle_t)context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C,
transposed_B
);
}
void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void cspmm_coo_very_sparse_naive_fp16(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive_fp16(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void cspmm_coo_very_sparse_naive_int8(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive_int8(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
//void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
// void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
void cgemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc);
}
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
void cgemm_4bit_inference(
int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
) {
gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
void *cget_managed_ptr(size_t bytes)
{
void *ptr;
void* cget_managed_ptr(size_t bytes) {
void* ptr;
CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
CUDA_CHECK_RETURN(cudaPeekAtLastError());
return ptr;
}
}
void cprefetch(void *ptr, size_t bytes, int device)
{
void cprefetch(void* ptr, size_t bytes, int device) {
int hasPrefetch = 0;
CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // 40ns overhead
if (hasPrefetch == 0) return;
CUDA_CHECK_RETURN(
cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)
); // 40ns overhead
if (hasPrefetch == 0)
return;
CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
}
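// Combined usage sketch for the two managed-memory helpers above
// (hypothetical size; cprefetch is a no-op on devices without concurrent
// managed access):
//   size_t bytes = size_t(1) << 20;
//   float* buf = (float*)cget_managed_ptr(bytes); // host-attached managed memory
//   cprefetch(buf, bytes, /*device=*/0);          // migrate to GPU 0 if supported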
#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \
#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void c##fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { fname##_##type_name(A, B, value, n); }
CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
void cgemm_4bit_inference_naive_fp16(
int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
void cgemm_4bit_inference_naive_bf16(
int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{ gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
void cgemm_4bit_inference_naive_fp32(
int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_cpu_fp32(
float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n
) {
quantize_cpu(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_cpu_fp32(
float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n
) {
dequantize_cpu(code, A, absmax, out, blocksize, n);
}
}
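// A minimal CPU round trip through the two exports above (sketch; `code` is a
// hypothetical 256-entry quantization table normalized to [-1, 1], and the
// sizes are illustrative):
//   long long n = 1 << 16, blocksize = 4096;
//   std::vector<float> A(n), out(n);
//   std::vector<float> absmax((n + blocksize - 1) / blocksize); // one scale per block
//   std::vector<unsigned char> q(n);
//   cquantize_blockwise_cpu_fp32(code, A.data(), absmax.data(), q.data(), blocksize, n);
//   cdequantize_blockwise_cpu_fp32(code, q.data(), absmax.data(), out.data(), blocksize, n);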