Unverified commit 4955d136, authored by Matthew Douglas, committed by GitHub

Apply clang-format rules (#1678)

parent 61db0859
@@ -26,10 +26,12 @@ void quantize_block(const quantize_block_args& args) {
        if (idx < 255) {
            float dist_left = fabs(normed_value - (args.code[idx]));
            float dist_right = fabs(normed_value - (args.code[idx + 1]));
            if (dist_right < dist_left) {
                idx += 1;
            }
        }

        // 5. store index
        args.out[i] = (unsigned char)idx;
    }
}
@@ -28,7 +28,8 @@
// The maximum number of resident threads per SM varies by arch.
// For A100/H100 and all prior to Turing, it is 2048, which allows
// for 2 full blocks of 1024 threads per SM.
// Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
#if __CUDA_ARCH__ == 750
#define BNB_MAX_THREADS_PER_SM 1024
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
......
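The comment above pins down why BNB_MAX_THREADS_PER_SM matters: with 2048 resident threads on most architectures, two full 1024-thread blocks fit per SM, but only one on Turing (750). As a minimal sketch of how such a constant can feed a kernel's occupancy hint via the standard __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor) form; the kernel and constants below are hypothetical and not part of this commit:

// Illustrative only: derive how many full blocks of a given size can be
// resident on one SM under a BNB_MAX_THREADS_PER_SM-style limit.
#define BNB_MAX_THREADS_PER_SM 2048 // e.g. A100/H100 per the comment above

constexpr int kThreadsPerBlock = 1024;
constexpr int kBlocksPerSM = BNB_MAX_THREADS_PER_SM / kThreadsPerBlock; // 2 on A100/H100, 1 on Turing

__global__ void __launch_bounds__(kThreadsPerBlock, kBlocksPerSM) example_kernel(float* data, int n) {
    // Hypothetical body: each thread touches one element.
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] *= 2.0f;
}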
@@ -5,21 +5,18 @@
using namespace BinSearch;

struct quantize_block_args {
    BinAlgo<Scalar, float, Direct2>* bin_searcher;
    float* code;
    float* A;
    float* absmax;
    unsigned char* out;
    long long block_end;
    long long block_idx;
    long long threadidx;
    long long blocksize;
};

void quantize_block(const quantize_block_args& args);

#endif
@@ -4,7 +4,7 @@
using namespace BinSearch;

void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n) {
    for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
        long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
        long long block_end = block_idx + valid_items;
@@ -13,8 +13,7 @@ void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, lo
    }
}

void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) {
    // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
    code[0] = -1.0f;
@@ -28,15 +27,13 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
    int thread_wave_size = 256;
    // we chunk the threads into waves of 256 since the max limit is
    // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
    for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) {
        long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
        std::vector<std::thread> threads(valid_chunks);
        std::vector<quantize_block_args> args(valid_chunks);

        int chunks_processed = 0;
        for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) {
            long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
            long long block_end = block_idx + valid_items;
@@ -53,11 +50,12 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
            threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
            chunks_processed += 1;
            if (chunks_processed == valid_chunks) {
                break;
            }
        }

        for (int i = 0; i < valid_chunks; i++)
            threads[i].join();
    }
}
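The comments in quantize_cpu above explain the wave-of-256 threading: Linux caps the number of threads a process may create (somewhere between 16k and 64k), so the host code never keeps more than thread_wave_size std::threads alive at once. A standalone sketch of that pattern under those assumptions, with a placeholder work function rather than the library's API:

#include <algorithm>
#include <thread>
#include <vector>

// Illustrative sketch of the wave-chunked threading pattern used above:
// never keep more than `wave` std::threads alive at once, so the process
// stays well below the OS thread limit. `do_block` is a placeholder.
void process_in_waves(long long num_blocks, int wave, void (*do_block)(long long)) {
    for (long long offset = 0; offset < num_blocks; offset += wave) {
        long long count = std::min<long long>(wave, num_blocks - offset);
        std::vector<std::thread> threads;
        threads.reserve(count);
        for (long long b = 0; b < count; b++)
            threads.emplace_back(do_block, offset + b);
        for (auto& t : threads)
            t.join(); // wait for the whole wave before starting the next one
    }
}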
@@ -4,7 +4,7 @@
#include <iostream>
#include <stdio.h>

void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n);
void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n);

#endif
@@ -3,26 +3,42 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include "common.cuh"
#include "kernels.cuh"

#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cuda_fp16.h>
#include <math_constants.h>
#include <mma.h>

#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096

__device__ static float nf4_data[16] = {
    -1.0,
    -0.6961928009986877,
    -0.5250730514526367,
    -0.39491748809814453,
    -0.28444138169288635,
    -0.18477343022823334,
    -0.09105003625154495,
    0.0,
    0.07958029955625534,
    0.16093020141124725,
    0.24611230194568634,
    0.33791524171829224,
    0.44070982933044434,
    0.5626170039176941,
    0.7229568362236023,
    1.0
};

// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
__device__ float atomicMax(float* address, float val) {
@@ -30,42 +46,35 @@ __device__ float atomicMax(float* address, float val) {
    int old = *address_as_i, assumed;
    do {
        assumed = old;
        old = atomicCAS(reinterpret_cast<int*>(address), assumed, __float_as_int(fmaxf(val, __int_as_float(assumed))));
    } while (assumed != old);
    return __int_as_float(old);
}
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
    float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
    if ((val & 0b0100) == 4)                        // 0
        if ((val & 0b0010) == 2)                    // 01
            if ((val & 0b0001) == 1)                // 111
                return 0.25000000f * absmax * sign; // 1111
            else
                return 0.16666667f * absmax * sign; // 1110
        else if ((val & 0b0001) == 1)               // 110
            return 0.50000000f * absmax * sign;     // 1101
        else
            return 0.33333333f * absmax * sign;     // 1100
    else if ((val & 0b0010) == 2)                   // 10
        if ((val & 0b0001) == 1)                    // 101
            return 1.00000000f * absmax * sign;     // 1011
        else
            return 0.66666667f * absmax * sign;     // 1010
    else if ((val & 0b0001) == 1)                   // 100
        return 5.208333333e-03f * absmax * sign;    // 1001
    else
        return 0.00000000f * absmax * sign;         // 1000
}
__device__ unsigned char dQuantizeFP4(float x) {
    // FP4 with bias of 3
    // first bit is a sign
    // subnormals
@@ -78,7 +87,6 @@ __device__ unsigned char dQuantizeFP4(float x)
    // 0b010 = 8
    // 0b011 = 12
    // we do a binary search
    // the pivots are divided by 12 (the FP4 absmax)
    // since we assume input data is in [-1.0, 1.0]
@@ -89,148 +97,124 @@ __device__ unsigned char dQuantizeFP4(float x)
    int sign = x < 0 ? 0b1000 : 0b0000;
    x = fabsf(x);
    if (x > 0.29166667f)
        if (x > 0.583333f)
            if (x > 0.8333333f)
                return 0b0011 + sign;
            else
                return 0b0010 + sign;
        else if (x > 0.4166667f)
            return 0b101 + sign;
        else
            return 0b100 + sign;
    else if (x > 0.0859375f)
        if (x > 0.20833333f)
            return 0b0111 + sign;
        else
            return 0b0110 + sign;
    else if (x > 0.00260417f)
        return 0b0001 + sign;
    else
        return 0b0000 + sign;
}
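Taken together, dDequantizeFP4Tree and dQuantizeFP4 above define eight FP4 magnitudes (selected by the low three bits), with bit 0b1000 as the sign and the result scaled by absmax. A host-side sketch that spells out the lookup table implied by the tree above; the table is read off that tree and the helper is illustrative only:

// Magnitudes encoded by the low three bits of an FP4 value, as implied by
// dDequantizeFP4Tree above; bit 0b1000 carries the sign and the result is
// scaled by absmax. Table and helper are illustrative, not part of the commit.
static const float fp4_magnitudes[8] = {
    0.00000000f,      // 0b000
    5.208333333e-03f, // 0b001
    0.66666667f,      // 0b010
    1.00000000f,      // 0b011
    0.33333333f,      // 0b100
    0.50000000f,      // 0b101
    0.16666667f,      // 0b110
    0.25000000f       // 0b111
};

float dequantize_fp4_lookup(unsigned char val, float absmax) {
    float sign = (val & 0b1000) ? -1.0f : 1.0f;
    return fp4_magnitudes[val & 0b0111] * sign * absmax;
}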
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {
    // the values for this tree was generated by test_normal_map_tree
    // in the file tests/test_functional.py
    if ((val & 0b1000) == 8)
        if ((val & 0b0100) == 4)             // 1
            if ((val & 0b0010) == 2)         // 11
                if ((val & 0b0001) == 1)     // 111
                    return 1.0f;
                else
                    return 0.7229568362236023f;
            else if ((val & 0b0001) == 1)    // 110
                return 0.5626170039176941f;
            else
                return 0.44070982933044434f;
        else if ((val & 0b0010) == 2)        // 10
            if ((val & 0b0001) == 1)         // 101
                return 0.33791524171829224f;
            else
                return 0.24611230194568634f;
        else if ((val & 0b0001) == 1)        // 100
            return 0.16093020141124725f;
        else
            return 0.07958029955625534f;
    else if ((val & 0b0100) == 4)            // 0
        if ((val & 0b0010) == 2)             // 01
            if ((val & 0b0001) == 1)         // 011
                return 0.0f;
            else
                return -0.09105003625154495f;
        else if ((val & 0b0001) == 1)        // 010
            return -0.18477343022823334f;
        else
            return -0.28444138169288635f;
    else if ((val & 0b0010) == 2)            // 00
        if ((val & 0b0001) == 1)             // 001
            return -0.39491748809814453f;
        else
            return -0.5250730514526367f;
    else if ((val & 0b0001) == 1)            // 000
        return -0.6961928009986877f;
    else
        return -1.0f;
}
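The NF4 tree above returns exactly the sixteen values of the nf4_data table defined near the top of this file, indexed by the 4-bit code. A host-side sketch of that lookup-table view (illustrative, not part of the commit):

// Host-side copy of the nf4_data table defined earlier in this file.
// dDequantizeNF4(val) returns host_nf4_data[val] for val in [0, 15];
// the helper below is illustrative only.
static const float host_nf4_data[16] = {
    -1.0f,                -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f,  0.16093020141124725f,  0.24611230194568634f,  0.33791524171829224f,
    0.44070982933044434f,  0.5626170039176941f,   0.7229568362236023f,   1.0f
};

float dequantize_nf4_lookup(unsigned char val) { return host_nf4_data[val & 0x0F]; }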
__device__ unsigned char dQuantizeNF4(float x) {
    // the values for this tree was generated by test_normal_map_tree
    // in the file tests/test_functional.py
    if (x > 0.03979014977812767f)
        if (x > 0.3893125355243683f)             // 1
            if (x > 0.6427869200706482f)         // 11
                if (x > 0.8614784181118011f)     // 111
                    return 0b1111;
                else
                    return 0b1110;
            else if (x > 0.5016634166240692f)    // 110
                return 0b1101;
            else
                return 0b1100;
        else if (x > 0.2035212516784668f)        // 10
            if (x > 0.2920137718319893f)         // 101
                return 0b1011;
            else
                return 0b1010;
        else if (x > 0.1202552504837513f)        // 100
            return 0b1001;
        else
            return 0b1000;
    else if (x > -0.33967943489551544f)          // 0
        if (x > -0.13791173323988914f)           // 01
            if (x > -0.045525018125772476f)      // 011
                return 0b0111;
            else
                return 0b0110;
        else if (x > -0.23460740596055984f)      // 010
            return 0b0101;
        else
            return 0b0100;
    else if (x > -0.6106329262256622f)           // 00
        if (x > -0.4599952697753906f)            // 001
            return 0b0011;
        else
            return 0b0010;
    else if (x > -0.8480964004993439f)           // 000
        return 0b0001;
    else
        return 0b0000;
}
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template <typename T> __device__ int sgn(T val) { return (T(0) < val) - (val < T(0)); }

template <int STOCHASTIC> __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) {
    int pivot = 127;
    int upper_pivot = 255;
    int lower_pivot = 0;
@@ -240,71 +224,60 @@ __device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
    float val = smem_code[pivot];

    // i>>=1 = {32, 16, 8, 4, 2, 1}
    for (int i = 64; i > 0; i >>= 1) {
        if (x > val) {
            lower_pivot = pivot;
            lower = val;
            pivot += i;
        } else {
            upper_pivot = pivot;
            upper = val;
            pivot -= i;
        }
        val = smem_code[pivot];
    }

    if (upper_pivot == 255)
        upper = smem_code[upper_pivot];
    if (lower_pivot == 0)
        lower = smem_code[lower_pivot];

    if (!STOCHASTIC) {
        if (x > val) {
            float midpoint = (upper + val) * 0.5f;
            if (x > midpoint) {
                return upper_pivot;
            } else
                return pivot;
        } else {
            float midpoint = (lower + val) * 0.5f;
            if (x < midpoint)
                return lower_pivot;
            else
                return pivot;
        }
    } else {
        if (x > val) {
            float dist_to_upper = fabsf(upper - x);
            float dist_full = upper - val;
            if (rand >= dist_to_upper / dist_full)
                return upper_pivot;
            else
                return pivot;
        } else {
            float dist_to_lower = fabsf(lower - x);
            float dist_full = val - lower;
            if (rand >= dist_to_lower / dist_full)
                return lower_pivot;
            else
                return pivot;
        }
    }
}
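In the STOCHASTIC branch of dQuantize above, a value lying between two adjacent code points is rounded to the upper one with probability proportional to how close it is to that point, so the quantized value is unbiased in expectation. A scalar host-side sketch of the rule, assuming rand is uniform in [0, 1) (illustrative only):

// Stochastic rounding between two adjacent code values `lower` and `upper`,
// mirroring the logic in dQuantize<1> above. With rand uniform in [0, 1),
// the expected result equals x, which is the point of stochastic quantization.
float stochastic_round(float x, float lower, float upper, float rand) {
    float dist_to_upper = upper - x; // how far x is from the upper code value
    float dist_full = upper - lower; // spacing between the two code values
    // round up when rand >= dist_to_upper / dist_full,
    // i.e. with probability (x - lower) / (upper - lower)
    return (rand >= dist_to_upper / dist_full) ? upper : lower;
}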
template <int SIGNED>
__device__ __forceinline__ unsigned char
quantize_2D(float* __restrict__ quadrants, float* __restrict__ const smem_code, float x) {
    int pivot = 127;
    int upper_pivot = 255;
    int lower_pivot = 0;
@@ -317,56 +290,48 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran
    int offset = 1;

    // i>>=1 = {32, 16, 8, 4, 2, 1}
    for (int i = 64; i > 0; i >>= 1) {
        if (x > val) {
            lower_pivot = pivot;
            lower = val;
            pivot += i;
            // val = i == 64 ? quadrants[2] : smem_code[pivot];
            local_pivot += offset;
        } else {
            upper_pivot = pivot;
            upper = val;
            pivot -= i;
            // val = i == 64 ? quadrants[0] : smem_code[pivot];
            local_pivot -= offset;
        }
        val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot];
        offset -= 1;
    }

    if (x > val) {
        midpoint = (upper + val) * 0.5f;
        if (x > midpoint)
            return upper_pivot;
        else
            return pivot;
    } else {
        midpoint = (lower + val) * 0.5f;
        if (x < midpoint)
            return lower_pivot;
        else
            return pivot;
    }
}
__launch_bounds__(TH, 4) __global__
    void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n) {
    const int n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
    int valid_items = (blockIdx.x + 1 == gridDim.x) ? n - (blockIdx.x * NUM_BLOCK) : NUM_BLOCK;
    const int base_idx = (blockIdx.x * NUM_BLOCK);

    float vals[NUM];
    unsigned char qvals[NUM];
    // const int lane_id = threadIdx.x % 2;

    typedef cub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
    typedef cub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
@@ -376,16 +341,13 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
    __shared__ float smem_code[256];
    //__shared__ float smem_code[2][257];

    if (threadIdx.x < 256) {
        smem_code[threadIdx.x] = code[threadIdx.x];
        // smem_code[0][threadIdx.x] = code[threadIdx.x];
        // smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x];
    }

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_BLOCK) {
        // number of values already processed in blocks +
        // number of values already processed in this block +
        // rand_offset % mod value
@@ -394,9 +356,8 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
        __syncthreads();
        LoadFloat(loadf).Load(&(A[i]), vals, valid_items);

#pragma unroll 4
        for (int j = 0; j < NUM; j++)
            qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);

        __syncthreads();
@@ -404,25 +365,30 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
    }
}
template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
) {
    const int n_full = gridDim.x * BLOCK_SIZE;
    int valid_items = 0;
    const int base_idx = (blockIdx.x * BLOCK_SIZE);

    T vals[NUM_PER_TH];
    float rand_vals[NUM_PER_TH];
    unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH];
    // float local_abs_max = -FLT_MAX;
    float local_abs_max = 0.0f;
    int local_rand_idx = 0;

    typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockStore<
        unsigned char, BLOCK_SIZE / NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH,
        cub::BLOCK_STORE_WARP_TRANSPOSE>
        StoreChar;
    typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_PER_TH> BlockReduce;
    typedef cub::BlockLoad<float, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

    __shared__ typename LoadT::TempStorage loadt;
    __shared__ typename LoadFloat::TempStorage loadf;
@@ -431,12 +397,11 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
    __shared__ float smem_code[256];
    __shared__ float smem_absmax_value[1];

    if (DATA_TYPE == General8bit)
        for (int i = threadIdx.x; i < 256; i += blockDim.x)
            smem_code[i] = code[i];

    for (int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
        valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
        local_abs_max = -FLT_MAX;
@@ -447,8 +412,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
        // 2. broadcast local max
        // 3. normalize inputs and quantize

#pragma unroll NUM_PER_TH
        for (int j = 0; j < NUM_PER_TH; j++)
            local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

        local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
@@ -461,60 +426,57 @@
        local_abs_max = smem_absmax_value[0];

        if (STOCHASTIC) {
            local_rand_idx = ((blockIdx.x * NUM_BLOCK) + (threadIdx.x * NUM) + rand_offset) % (1024 - 4);
            LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
        }

        unsigned char packed_4bit = 0;
        switch (DATA_TYPE) {
        case General8bit:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH; j++) {
                if (!STOCHASTIC)
                    qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j]) * local_abs_max);
                else
                    qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j]) * local_abs_max);
            }
            break;
        case FP4:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH / 2; j++) {
                packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
                packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
                qvals[j] = packed_4bit;
            }
            break;
        case NF4:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH / 2; j++) {
                packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
                packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
                qvals[j] = packed_4bit;
            }
            break;
        }

        __syncthreads();
        StoreChar(storec).Store(
            &(out[(DATA_TYPE > 0) ? i / 2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items + 1) / 2 : valid_items
        );
    }
}
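For the FP4 and NF4 cases in kQuantizeBlockwise above, each pair of inputs is packed into one output byte, the first element in the high nibble and the second in the low nibble, which is why only (valid_items + 1) / 2 bytes are stored and why kDequantizeBlockwise later unpacks with >> 4 and & 0x0F. A minimal sketch of that packing convention (illustrative only, not the kernel code):

// Pack two 4-bit codes into one byte: first code in the high nibble,
// second in the low nibble (the convention used by kQuantizeBlockwise above).
unsigned char pack_4bit_pair(unsigned char first, unsigned char second) {
    return (unsigned char)((first << 4) | (second & 0x0F));
}

// Unpack in the same order kDequantizeBlockwise uses: high nibble first.
void unpack_4bit_pair(unsigned char packed, unsigned char& first, unsigned char& second) {
    first = packed >> 4;
    second = packed & 0x0F;
}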
template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) {
    const int n_load = (gridDim.x * TILE_SIZE);
    int valid_items_load = 0;
    int valid_items_store = 0;
    const int base_idx = (blockIdx.x * TILE_SIZE);

    T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];
    unsigned char qvals[NUM_PER_TH];
    float local_abs_max = -FLT_MAX;
@@ -524,15 +486,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
    __shared__ typename LoadChar::TempStorage loadchar;
    __shared__ typename StoreT::TempStorage storet;

    for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
        if (DATA_TYPE > 0) {
            valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
            valid_items_store = min(TILE_SIZE * 2, n - i * 2);
        } else {
            valid_items_load = min(TILE_SIZE, n - i);
            valid_items_store = valid_items_load;
        }
@@ -540,72 +498,62 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
        // Since blocksize will always be a power-of-2, we avoid more expensive
        // division by the blocksize and instead use a shift operation.
        // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize.
        local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) >> (31 - __clz(blocksize))]);

        __syncthreads();
        LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);

        switch (DATA_TYPE) {
        case General8bit:
            // load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH; j++)
                vals[j] = __ldg(&code[qvals[j]]) * local_abs_max;
            break;
        case FP4:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH; j++) {
                vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
                vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
            }
            break;
        case NF4:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH; j++) {
                vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max;
                vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max;
            }
            break;
        }

        __syncthreads();
        StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i * 2 : i]), vals, valid_items_store);
    }
}
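The comment in kDequantizeBlockwise above leans on the identity that, for a power-of-two blocksize, dividing an index by blocksize is the same as shifting it right by log2(blocksize), which the kernel obtains as 31 - __clz(blocksize). A host-side sketch of the same identity, using the GCC/Clang builtin __builtin_clz in place of the CUDA intrinsic (illustrative only):

#include <cassert>

// For a power-of-two blocksize, index / blocksize == index >> log2(blocksize).
// The CUDA kernel computes log2(blocksize) as 31 - __clz(blocksize); here the
// GCC/Clang builtin __builtin_clz plays the same role on the host.
int absmax_index(int index, int blocksize) {
    int log2_blocksize = 31 - __builtin_clz((unsigned)blocksize);
    return index >> log2_blocksize;
}

int main() {
    const int sizes[] = {64, 256, 4096};
    for (int blocksize : sizes)
        for (int index = 0; index < 3 * blocksize; index += 37)
            assert(absmax_index(index, blocksize) == index / blocksize);
    return 0;
}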
__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n) {
    const unsigned int numThreads = blockDim.x * gridDim.x;
    const int idx = (blockIdx.x * blockDim.x) + threadIdx.x;

    __shared__ float smem_code[256];
    if (threadIdx.x < 256) {
        smem_code[threadIdx.x] = code[threadIdx.x];
    }

    __syncthreads();

    for (int i = idx; i < n; i += numThreads) {
        out[i] = smem_code[A[i]];
    }
}

template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit2State(
    T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,
    const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
) {
    const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
    const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
    int valid_items = 0;

@@ -614,12 +562,12 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
    float s1_vals[NUM_VALS];
    float s2_vals[NUM_VALS];

    const float correction1 = 1.0f / (1.0f - powf(beta1, step));
    const float correction2 = 1.0f / (1.0f - powf(beta2, step));

    typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
    typedef cub::BlockLoad<float, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
    typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_VALS> BlockReduce;

    __shared__ union {
        typename Load::TempStorage load;
@@ -627,8 +575,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
        typename BlockReduce::TempStorage reduce;
    } temp_storage;

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
        valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;

        __syncthreads();
@@ -638,60 +585,56 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
        __syncthreads();
        LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f);

#pragma unroll NUM_VALS
        for (unsigned int j = 0; j < NUM_VALS; j++)
            g_vals[j] = gnorm_scale * ((float)g_vals[j]);

#pragma unroll NUM_VALS
        for (unsigned int j = 0; j < NUM_VALS; j++) {
            switch (OPTIMIZER) {
            case ADAM:
                s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]));
                s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j])));
                s1_vals[j] *= correction1;
                s2_vals[j] *= correction2;
                s1_vals[j] = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update
                s1_vals[j] *= s1_vals[j];                            // update l2 norm (update*update)
                break;
            }
        }

#pragma unroll NUM_VALS - 1
        for (unsigned int j = 1; j < NUM_VALS; j++)
            s1_vals[0] += s1_vals[j];

        __syncthreads();
        s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]);

        if (threadIdx.x == 0)
            atomicAdd(&unorm[0], s1_vals[0]);

        __syncwarp();
    }
}
#define NUM_PER_THREAD 4

template <typename T, int OPTIMIZER>
__launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State(
    T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
    const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
    const int n
) {
    const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) +
                       (n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD));
    const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
    int valid_items = 0;
    float update_scale = 0.0f;
    T g_vals[NUM_PER_THREAD];
    T p_vals[NUM_PER_THREAD];

    float s1_vals[NUM_PER_THREAD];
    float s2_vals[NUM_PER_THREAD];

@@ -700,18 +643,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
    // TODO: Mark with [[maybe_unused]] after upgrade to min compiler.
    float s3_vals[NUM_PER_THREAD];

    const float correction1 = 1.0f - powf(beta1, step);
    const float correction2 = sqrtf(1.0f - powf(beta2, step));
    const float step_size = -lr * correction2 / correction1;

    if (max_unorm > 0.0f) {
        update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
        if (update_scale > max_unorm * param_norm) {
            update_scale = (max_unorm * param_norm) / update_scale;
        } else {
            update_scale = 1.0f;
        }
    } else {
        update_scale = 1.0f;
    }

    typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
    typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;

@@ -726,9 +671,8 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
        typename StoreFloat::TempStorage storef;
    } temp_storage;

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) {
        valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;

        __syncthreads();
        Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
@@ -746,15 +690,13 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
            LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items);
        }

#pragma unroll 4
        for (unsigned int j = 0; j < NUM_PER_THREAD; j++)
            g_vals[j] = gnorm_scale * ((float)g_vals[j]);

#pragma unroll 4
        for (unsigned int j = 0; j < NUM_PER_THREAD; j++) {
            switch (OPTIMIZER) {
            case ADEMAMIX:
                // m1 update: m1 = beta1 * m1 + (1-beta1) * g
                s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]);
@@ -765,11 +707,8 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
                // nu update: nu = beta2 * nu + (1-beta2) * g^2
                s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]);

                p_vals[j] = (float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) /
                                                     ((sqrtf(s2_vals[j]) / correction2) + eps));

                if (weight_decay > 0.0f)
                    p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
@@ -777,14 +716,14 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
                break;
            case ADAM:
                if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) {
                    s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]));
                    s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j])));
                    p_vals[j] = ((float)p_vals[j]) +
                                (update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (eps * correction2))));

                    if (weight_decay > 0.0f)
                        p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
                }
                break;
            }
@@ -804,15 +743,13 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
    }
}
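The ADAM case above folds the usual bias corrections into one step size: correction1 = 1 - beta1^step, correction2 = sqrt(1 - beta2^step), step_size = -lr * correction2 / correction1, and the update divides by sqrt(v) + eps * correction2, which is algebraically the standard m_hat / (sqrt(v_hat) + eps) update. A scalar host-side sketch of one such step, with update_scale and weight decay omitted (illustrative only):

#include <cmath>

// One bias-corrected Adam step for a single scalar parameter, mirroring the
// arithmetic of the ADAM case in kOptimizer32bit2State above (the vectorized
// layout, update clipping and weight decay are omitted for clarity).
void adam_step_scalar(float& p, float g, float& m, float& v,
                      float beta1, float beta2, float eps, float lr, int step) {
    m = beta1 * m + (1.0f - beta1) * g;     // first-moment estimate
    v = beta2 * v + (1.0f - beta2) * g * g; // second-moment estimate
    float correction1 = 1.0f - powf(beta1, (float)step);
    float correction2 = sqrtf(1.0f - powf(beta2, (float)step));
    float step_size = -lr * correction2 / correction1;
    p = p + step_size * (m / (sqrtf(v) + eps * correction2));
}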
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit1State(
    T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps,
    const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
) {
    const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
    const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
    int valid_items = 0;
@@ -820,9 +757,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
    float s1_vals[NUM_VALS];

    typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
    typedef cub::BlockLoad<float, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
    typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_VALS> BlockReduce;

    __shared__ union {
        typename Load::TempStorage load;
@@ -830,8 +767,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
        typename BlockReduce::TempStorage reduce;
    } temp_storage;

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
        valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;

        __syncthreads();
@@ -839,72 +775,74 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
        __syncthreads();
        LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);

#pragma unroll NUM_VALS
        for (unsigned int j = 0; j < NUM_VALS; j++)
            g_vals[j] = gnorm_scale * ((float)g_vals[j]);

#pragma unroll NUM_VALS
        for (unsigned int j = 0; j < NUM_VALS; j++) {
            switch (OPTIMIZER) {
            case MOMENTUM:
                if (step == 1)
                    s1_vals[j] = (float)g_vals[j];                        // state update
                else
                    s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); // state update
                s1_vals[j] = s1_vals[j] * s1_vals[j];                     // update norm
                break;
            case LION:
                s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * (float)g_vals[j]); // state update
                break;
            case RMSPROP:
                s1_vals[j] =
                    s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j])); // state update
                s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps);                  // update value
                s1_vals[j] = s1_vals[j] * s1_vals[j];                                                // update norm
                break;
            case ADAGRAD:
                s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]);  // state update
                s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value
                s1_vals[j] = s1_vals[j] * s1_vals[j];                               // update norm
                break;
            }
        }

#pragma unroll
        for (unsigned int j = 1; j < NUM_VALS; j++)
            s1_vals[0] += s1_vals[j];

        __syncthreads();
        s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items);

        if (threadIdx.x == 0)
            atomicAdd(&unorm[0], s1_vals[0]);

        __syncwarp();
    }
}
template <typename T, int OPTIMIZER>
__launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State(
    T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1,
    const float beta2, const float eps, const float weight_decay, const int step, const float lr,
    const float gnorm_scale, const bool skip_zeros, const int n
) {
    const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) +
                       (n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD));
    const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
    int valid_items = 0;
    float update_scale = 0.0f;

    if (max_unorm > 0.0f) {
        update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
        if (update_scale > max_unorm * param_norm + eps) {
            update_scale = (max_unorm * param_norm + eps) / update_scale;
        } else {
            update_scale = 1.0f;
        }
    } else {
        update_scale = 1.0f;
    }

    T g_vals[NUM_PER_THREAD];
    T p_vals[NUM_PER_THREAD];

@@ -924,9 +862,8 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
        typename StoreFloat::TempStorage storef;
    } temp_storage;

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) {
        valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;

        __syncthreads();
        Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
@@ -935,40 +872,39 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
        __syncthreads();
        Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);

#pragma unroll 4
        for (unsigned int j = 0; j < NUM_PER_THREAD; j++) {
            g_vals[j] = gnorm_scale * ((float)g_vals[j]);
            if (weight_decay > 0.0f)
                g_vals[j] = (float)g_vals[j] + (((float)p_vals[j]) * weight_decay);
        }

#pragma unroll 4
        for (unsigned int j = 0; j < NUM_PER_THREAD; j++) {
            if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) {
                switch (OPTIMIZER) {
                case MOMENTUM:
                    if (step == 1)
                        s1_vals[j] = (float)g_vals[j];
                    else
                        s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]);

                    p_vals[j] = ((float)p_vals[j]) + update_scale * (-lr * (s1_vals[j]));
                    break;
                case LION:
                    p_vals[j] =
                        ((float)p_vals[j]) -
                        update_scale * (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_vals[j]))));
                    s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * ((float)g_vals[j]));
                    break;
                case RMSPROP:
                    s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j]));
                    p_vals[j] = ((float)p_vals[j]) -
                                update_scale * (lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps));
                    break;
                case ADAGRAD:
                    s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]);
                    p_vals[j] = ((float)p_vals[j]) - lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps);
                    break;
                }
            }
...@@ -981,25 +917,21 @@ __global__ void kOptimizer32bit1State(T *g, T *p, ...@@ -981,25 +917,21 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
} }
} }
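For reference, the per-element math applied by this 32-bit single-state kernel (momentum, LION, RMSProp, Adagrad) can be restated as a small host-side C++ sketch. Everything below — the Optimizer enum, update_1state, and the values in main — is illustrative and not part of this codebase; it assumes the gradient has already been scaled by gnorm_scale, and it mirrors this kernel's coupled weight decay (added into the gradient for every optimizer).

#include <cmath>
#include <cstdio>

enum Optimizer { MOMENTUM, LION, RMSPROP, ADAGRAD };

// Sign function used by the LION update (illustrative stand-in for sgn()).
float sgn(float x) { return (x > 0.0f) - (x < 0.0f); }

// Updates one parameter p with gradient g and running state s1.
void update_1state(Optimizer opt, float& p, float g, float& s1, float lr, float beta1, float beta2, float eps,
                   float weight_decay, int step, float update_scale) {
    if (weight_decay > 0.0f)
        g += p * weight_decay; // coupled weight decay, as in the kernel's first loop
    switch (opt) {
    case MOMENTUM:
        s1 = (step == 1) ? g : s1 * beta1 + g;
        p += update_scale * (-lr * s1);
        break;
    case LION:
        p -= update_scale * lr * sgn(s1 * beta1 + (1.0f - beta1) * g);
        s1 = s1 * beta2 + (1.0f - beta2) * g;
        break;
    case RMSPROP:
        s1 = s1 * beta1 + (1.0f - beta1) * g * g;
        p -= update_scale * lr * g / (std::sqrt(s1) + eps);
        break;
    case ADAGRAD:
        s1 += g * g;
        p -= lr * g / (std::sqrt(s1) + eps); // no update_scale, matching the kernel
        break;
    }
}

int main() {
    float p = 1.0f, s1 = 0.0f;
    update_1state(MOMENTUM, p, 0.1f, s1, 0.01f, 0.9f, 0.999f, 1e-8f, 0.0f, 1, 1.0f);
    std::printf("p after one momentum step: %f\n", p);
}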
#define NUM8BIT 16 #define NUM8BIT 16
#define NUM_THREADS 256 #define NUM_THREADS 256
#define NUM_PER_BLOCK 4096 #define NUM_PER_BLOCK 4096
template<typename T, int OPTIMIZER> template <typename T, int OPTIMIZER>
__global__ void __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit2State(
__launch_bounds__(NUM_THREADS, 2) T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2,
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, float* unorm, const float beta1, const float beta2, const float eps, const int step,
float *unorm, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
const float beta1, const float beta2, float* new_max1, float* new_max2, const float gnorm_scale, const int n
const float eps, const int step, ) {
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
const float gnorm_scale, const int n)
{
const int n_full = gridDim.x * NUM_PER_BLOCK; const int n_full = gridDim.x * NUM_PER_BLOCK;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); int valid_items =
n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK);
float g_val = 0.0f; float g_val = 0.0f;
float local_max_s1 = -FLT_MAX; float local_max_s1 = -FLT_MAX;
float local_max_s2 = -FLT_MAX; float local_max_s2 = -FLT_MAX;
...@@ -1015,7 +947,6 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c ...@@ -1015,7 +947,6 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8; typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce; typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union { __shared__ union {
typename LoadT::TempStorage loadh; typename LoadT::TempStorage loadh;
typename LoadUInt8::TempStorage loadc; typename LoadUInt8::TempStorage loadc;
...@@ -1025,17 +956,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c ...@@ -1025,17 +956,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
__shared__ float smem_quantiles1[256]; __shared__ float smem_quantiles1[256];
__shared__ float smem_quantiles2[256]; __shared__ float smem_quantiles2[256];
if(threadIdx.x < 256) if (threadIdx.x < 256) {
{
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x];
} }
__syncthreads(); __syncthreads();
for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS * gridDim.x * NUM8BIT) {
{ valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads(); __syncthreads();
...@@ -1044,37 +973,33 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c ...@@ -1044,37 +973,33 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128);
__syncthreads(); __syncthreads();
#pragma unroll 16 #pragma unroll 16
for(int j = 0; j < NUM8BIT; j++) for (int j = 0; j < NUM8BIT; j++) {
{
g_val = g_vals[j]; g_val = g_vals[j];
g_val *= gnorm_scale; g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0] * beta1;
s1_vals[j] += (1.0f-beta1)*g_val; s1_vals[j] += (1.0f - beta1) * g_val;
local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
} }
#pragma unroll 16 #pragma unroll 16
for(int j = 0; j < NUM8BIT; j++) for (int j = 0; j < NUM8BIT; j++) {
{
g_val = g_vals[j]; g_val = g_vals[j];
g_val *= gnorm_scale; g_val *= gnorm_scale;
s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; s2_vals[j] = smem_quantiles2[r_c2[j]] * max2[0] * beta2;
s2_vals[j] += (1.0f-beta2)*g_val*g_val; s2_vals[j] += (1.0f - beta2) * g_val * g_val;
local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j]));
} }
if(unorm != NULL) if (unorm != NULL) {
{ #pragma unroll 16
#pragma unroll 16 for (int j = 0; j < NUM8BIT; j++) {
for(int j = 0; j < NUM8BIT; j++)
{
float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step));
float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step));
s1_vals[j] *= correction1; s1_vals[j] *= correction1;
s2_vals[j] *= correction2; s2_vals[j] *= correction2;
float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update float update_val = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update
local_unorm += update_val*update_val; local_unorm += update_val * update_val;
} }
} }
} }
...@@ -1083,17 +1008,17 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c ...@@ -1083,17 +1008,17 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
__syncthreads(); __syncthreads();
local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items); local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items);
if(unorm != NULL) if (unorm != NULL) {
{
__syncthreads(); __syncthreads();
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
} }
if(threadIdx.x == 0) if (threadIdx.x == 0) {
{
atomicMax(&new_max1[0], local_max_s1); atomicMax(&new_max1[0], local_max_s1);
atomicMax(&new_max2[0], local_max_s2); atomicMax(&new_max2[0], local_max_s2);
if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } if (unorm != NULL) {
atomicAdd(&unorm[0], local_unorm);
}
} }
} }
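The "precondition" pass above does not write the parameters; it dequantizes both 8-bit Adam states through their quantile tables and the current per-tensor absmax, applies the moment updates, and records the new absolute maxima (and optionally the squared update norm) needed by the follow-up kernel. The C++ sketch below restates that per-element math on the host; the struct, function, and toy codebooks are illustrative only, and the state2 codes are picked from the positive half since the variance state is non-negative.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct PrecondResult { float new_max1, new_max2, unorm; };

// Dequantize byte codes c1/c2, apply the Adam moment updates with gradient g,
// and track the new maxima used to re-quantize, plus the squared update norm.
PrecondResult precondition_2state(const std::vector<unsigned char>& c1, const std::vector<unsigned char>& c2,
                                  const std::vector<float>& g, const float q1[256], const float q2[256],
                                  float max1, float max2, float beta1, float beta2, float eps, int step) {
    PrecondResult r{0.0f, 0.0f, 0.0f};
    float corr1 = 1.0f / (1.0f - std::pow(beta1, step));
    float corr2 = 1.0f / (1.0f - std::pow(beta2, step));
    for (size_t j = 0; j < g.size(); ++j) {
        float s1 = q1[c1[j]] * max1 * beta1 + (1.0f - beta1) * g[j];
        float s2 = q2[c2[j]] * max2 * beta2 + (1.0f - beta2) * g[j] * g[j];
        r.new_max1 = std::fmax(r.new_max1, std::fabs(s1));
        r.new_max2 = std::fmax(r.new_max2, std::fabs(s2));
        float upd = (s1 * corr1) / (std::sqrt(s2 * corr2) + eps); // bias-corrected update
        r.unorm += upd * upd;                                     // squared update norm
    }
    return r;
}

int main() {
    float q1[256], q2[256];
    for (int i = 0; i < 256; ++i) q1[i] = q2[i] = (i - 127.5f) / 127.5f; // toy codebooks
    std::vector<unsigned char> c1 = {128, 130}, c2 = {200, 180};
    std::vector<float> g = {0.05f, -0.02f};
    PrecondResult r = precondition_2state(c1, c2, g, q1, q2, 0.1f, 0.01f, 0.9f, 0.999f, 1e-8f, 1);
    std::printf("new_max1=%f new_max2=%f unorm=%f\n", r.new_max1, r.new_max2, r.unorm);
}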
...@@ -1101,20 +1026,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c ...@@ -1101,20 +1026,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
#define NUM_THREADS2 1024 #define NUM_THREADS2 1024
#define NUM_PER_BLOCK2 4096 #define NUM_PER_BLOCK2 4096
template<typename T, int OPTIMIZER> template <typename T, int OPTIMIZER>
__global__ void __global__ void __launch_bounds__(NUM_THREADS2, 1) kOptimizerStatic8bit2State(
__launch_bounds__(NUM_THREADS2, 1) T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm,
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr,
const float *unorm, const float max_unorm, const float param_norm, \ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
const float beta1, const float beta2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n
const float eps, const int step, const float lr, ) {
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2, const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2;
float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
int valid_items = 0; int valid_items = 0;
float g_val = 0.0f; float g_val = 0.0f;
...@@ -1122,19 +1042,22 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha ...@@ -1122,19 +1042,22 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
float s2_vals[NUM_PER_THREAD2]; float s2_vals[NUM_PER_THREAD2];
const float correction1 = 1.0f - powf(beta1, step); const float correction1 = 1.0f - powf(beta1, step);
const float correction2 = sqrtf(1.0f - powf(beta2, step)); const float correction2 = sqrtf(1.0f - powf(beta2, step));
const float step_size = -lr*correction2/correction1; const float step_size = -lr * correction2 / correction1;
//const float step_size = -lr*correction2/correction1; // const float step_size = -lr*correction2/correction1;
float new_max_val1 = 1.0f/new_max1[0]; float new_max_val1 = 1.0f / new_max1[0];
float new_max_val2 = 1.0f/new_max2[0]; float new_max_val2 = 1.0f / new_max2[0];
float update_scale = 1.0f; float update_scale = 1.0f;
if(max_unorm > 0.0f) if (max_unorm > 0.0f) {
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } if (update_scale > max_unorm * param_norm) {
else{ update_scale = 1.0f; } update_scale = (max_unorm * param_norm) / update_scale;
} else {
update_scale = 1.0f;
}
} else {
update_scale = 1.0f;
} }
else{ update_scale = 1.0f; }
unsigned char c1s[NUM_PER_THREAD2]; unsigned char c1s[NUM_PER_THREAD2];
unsigned char c2s[NUM_PER_THREAD2]; unsigned char c2s[NUM_PER_THREAD2];
...@@ -1156,19 +1079,17 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha ...@@ -1156,19 +1079,17 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
typename StoreT::TempStorage storeh; typename StoreT::TempStorage storeh;
} temp_storage; } temp_storage;
if(threadIdx.x < 512) if (threadIdx.x < 512) {
{ if (threadIdx.x < 256)
if(threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
else else
smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; smem_quantiles2[threadIdx.x - 256] = quantiles2[threadIdx.x - 256];
} }
__syncthreads(); __syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) {
{ valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads(); __syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
...@@ -1177,42 +1098,42 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha ...@@ -1177,42 +1098,42 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
__syncthreads(); __syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) {
continue;
}
# pragma unroll 4 #pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) {
{
g_val = float(g_vals[j]); g_val = float(g_vals[j]);
g_val *= gnorm_scale; g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[c1s[j]]; s1_vals[j] = smem_quantiles1[c1s[j]];
s1_vals[j] = s1_vals[j]*max1[0]; s1_vals[j] = s1_vals[j] * max1[0];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val));
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1);
// make sure state1 term has still the same sign after quantization // make sure state1 term has still the same sign after quantization
// (not needed for state2 term which has only positive values) // (not needed for state2 term which has only positive values)
if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) {
{ if (s1_vals[j] > 0.0f)
if(s1_vals[j] > 0.0f)
c1s[j] += 1; c1s[j] += 1;
else else
c1s[j] -= 1; c1s[j] -= 1;
} }
s2_vals[j] = smem_quantiles2[c2s[j]]; s2_vals[j] = smem_quantiles2[c2s[j]];
s2_vals[j] = s2_vals[j]*max2[0]; s2_vals[j] = s2_vals[j] * max2[0];
s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val));
c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j] * new_max_val2);
} }
# pragma unroll 4 #pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) {
{ p_vals[j] = (T)(((float)p_vals[j]) +
p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); ((update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (correction2 * eps))))));
if(weight_decay > 0.0f) if (weight_decay > 0.0f)
p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); p_vals[j] = update_scale * ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
} }
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
...@@ -1224,22 +1145,16 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha ...@@ -1224,22 +1145,16 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
} }
} }
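After updating the moments, this kernel re-quantizes them against the new maxima and, for the signed first moment, nudges the chosen code by one if the nearest quantile flipped the sign. A minimal host-side sketch of nearest-quantile quantization plus that sign-preservation fix is shown below; dquantize_nearest and quantize_signed are illustrative names, not the dQuantize/quantize_2D device helpers.

#include <cmath>
#include <cstdio>

// Pick the index of the codebook entry closest to x (codebook has 256 entries).
unsigned char dquantize_nearest(const float code[256], float x) {
    unsigned char best = 0;
    float best_dist = 1e30f;
    for (int i = 0; i < 256; ++i) {
        float d = std::fabs(code[i] - x);
        if (d < best_dist) { best_dist = d; best = (unsigned char)i; }
    }
    return best;
}

unsigned char quantize_signed(const float code[256], float value, float inv_absmax) {
    unsigned char c = dquantize_nearest(code, value * inv_absmax);
    // If the chosen quantile has a different sign than the value, step one code
    // toward the correct sign so the state never changes sign through quantization.
    if (std::signbit(code[c]) != std::signbit(value))
        c += (value > 0.0f) ? 1 : -1;
    return c;
}

int main() {
    float code[256];
    for (int i = 0; i < 256; ++i) code[i] = (i - 127.5f) / 127.5f; // toy symmetric codebook
    unsigned char c = quantize_signed(code, 0.003f, 1.0f);
    std::printf("code index: %d, value: %f\n", c, code[c]);
}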
template <typename T, int OPTIMIZER>
template<typename T, int OPTIMIZER> __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit1State(
__global__ void T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1,
__launch_bounds__(NUM_THREADS, 2) const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1,
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n
float *unorm, ) {
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
const float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = gridDim.x * NUM_PER_BLOCK; const int n_full = gridDim.x * NUM_PER_BLOCK;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); int valid_items =
n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK);
float g_val = 0.0f; float g_val = 0.0f;
float local_max_s1 = -FLT_MAX; float local_max_s1 = -FLT_MAX;
float local_unorm = 0.0f; float local_unorm = 0.0f;
...@@ -1252,7 +1167,6 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c ...@@ -1252,7 +1167,6 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8; typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce; typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union { __shared__ union {
typename LoadT::TempStorage loadh; typename LoadT::TempStorage loadh;
typename LoadUInt8::TempStorage loadc; typename LoadUInt8::TempStorage loadc;
...@@ -1261,42 +1175,39 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c ...@@ -1261,42 +1175,39 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
__shared__ float smem_quantiles1[256]; __shared__ float smem_quantiles1[256];
if(threadIdx.x < 256) if (threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
__syncthreads(); __syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS * NUM8BIT) {
{ valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
__syncthreads(); __syncthreads();
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads(); __syncthreads();
LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
#pragma unroll 16 #pragma unroll 16
for(int j = 0; j < NUM8BIT; j++) for (int j = 0; j < NUM8BIT; j++) {
{
g_val = g_vals[j]; g_val = g_vals[j];
g_val *= gnorm_scale; g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0];
switch(OPTIMIZER) switch (OPTIMIZER) {
{
case ADAGRAD: case ADAGRAD:
case MOMENTUM: case MOMENTUM:
if(step == 1) if (step == 1)
s1_vals[j] = (float)g_vals[j]; s1_vals[j] = (float)g_vals[j];
else else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]);
if(unorm != NULL) if (unorm != NULL)
local_unorm += s1_vals[j]*s1_vals[j]; local_unorm += s1_vals[j] * s1_vals[j];
break; break;
case LION: case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val);
break; break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val));
break; break;
} }
...@@ -1306,44 +1217,44 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c ...@@ -1306,44 +1217,44 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
__syncthreads(); __syncthreads();
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } if (threadIdx.x == 0) {
if(unorm != NULL) atomicMax(&new_max1[0], local_max_s1);
{ }
if (unorm != NULL) {
__syncthreads(); __syncthreads();
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } if (threadIdx.x == 0) {
atomicAdd(&unorm[0], local_unorm);
}
} }
} }
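The unorm value accumulated by the precondition kernels feeds the update clipping at the top of the update kernels: if the pending update's norm exceeds max_unorm * param_norm, the whole step is rescaled, otherwise it is applied unchanged. The short sketch below restates that scale computation (with the +eps guard used by the 32-bit variant); the function name is illustrative.

#include <cmath>
#include <cstdio>

float compute_update_scale(float unorm_sq, float max_unorm, float param_norm, float eps) {
    if (max_unorm <= 0.0f)
        return 1.0f;                        // clipping disabled
    float update_norm = std::sqrt(unorm_sq);
    float limit = max_unorm * param_norm + eps;
    return (update_norm > limit) ? limit / update_norm : 1.0f;
}

int main() {
    std::printf("scale (within limit): %f\n", compute_update_scale(0.01f, 1.0f, 2.0f, 1e-8f));
    std::printf("scale (clipped):      %f\n", compute_update_scale(100.0f, 0.5f, 1.0f, 1e-8f));
}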
template<typename T, int OPTIMIZER> template <typename T, int OPTIMIZER>
__global__ void __global__ void __launch_bounds__(1024, 1) kOptimizerStatic8bit1State(
__launch_bounds__(1024, 1) T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm,
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float beta1, const float beta2, const float eps, const int step, const float lr,
const float *unorm, const float max_unorm, const float param_norm, float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale,
const float beta1, const float beta2, const int n
const float eps, const int step, const float lr, ) {
float* __restrict__ const quantiles1,
float* max1, float* new_max1, const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2;
float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
int valid_items = 0; int valid_items = 0;
float g_val = 0.0f; float g_val = 0.0f;
float s1_vals[NUM_PER_THREAD2]; float s1_vals[NUM_PER_THREAD2];
float new_max_val1 = 1.0f/new_max1[0]; float new_max_val1 = 1.0f / new_max1[0];
float update_scale = 1.0f; float update_scale = 1.0f;
if(max_unorm > 0.0f) if (max_unorm > 0.0f) {
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } if (update_scale > max_unorm * param_norm) {
else{ update_scale = 1.0f; } update_scale = (max_unorm * param_norm) / update_scale;
} else {
update_scale = 1.0f;
}
} else {
update_scale = 1.0f;
} }
else{ update_scale = 1.0f; }
unsigned char c1s[NUM_PER_THREAD2]; unsigned char c1s[NUM_PER_THREAD2];
T p_vals[NUM_PER_THREAD2]; T p_vals[NUM_PER_THREAD2];
...@@ -1363,69 +1274,69 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, ...@@ -1363,69 +1274,69 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
typename StoreT::TempStorage storeh; typename StoreT::TempStorage storeh;
} temp_storage; } temp_storage;
if(threadIdx.x < 256) if (threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
__syncthreads(); __syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) {
{ valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i;
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads(); __syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads(); __syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) {
continue;
}
# pragma unroll 4 #pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) {
{
g_val = float(g_vals[j]); g_val = float(g_vals[j]);
g_val *= gnorm_scale; g_val *= gnorm_scale;
if(weight_decay > 0.0f) { if (weight_decay > 0.0f) {
switch(OPTIMIZER) { switch (OPTIMIZER) {
case ADAGRAD: case ADAGRAD:
case MOMENTUM: case MOMENTUM:
case RMSPROP: case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay; g_val += ((float)p_vals[j]) * weight_decay;
break; break;
case LION: case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay);
break; break;
} }
} }
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; s1_vals[j] = smem_quantiles1[c1s[j]] * max1[0];
switch(OPTIMIZER){ switch (OPTIMIZER) {
case ADAGRAD: case ADAGRAD:
case MOMENTUM: case MOMENTUM:
if(step == 1) if (step == 1)
s1_vals[j] = g_vals[j]; s1_vals[j] = g_vals[j];
else else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); p_vals[j] = ((float)p_vals[j]) + (-lr * update_scale * (s1_vals[j]));
break; break;
case LION: case LION:
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); p_vals[j] =
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); ((float)p_vals[j]) - (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_val))));
s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val);
break; break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); p_vals[j] = ((float)p_vals[j]) - (lr * __fdividef(g_val, sqrtf(s1_vals[j]) + eps));
break; break;
} }
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1);
// make sure state1 term has still the same sign after quantization // make sure state1 term has still the same sign after quantization
if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) {
{ if (s1_vals[j] > 0.0f)
if(s1_vals[j] > 0.0f)
c1s[j] += 1; c1s[j] += 1;
else else
c1s[j] -= 1; c1s[j] -= 1;
...@@ -1439,15 +1350,13 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, ...@@ -1439,15 +1350,13 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
} }
} }
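One detail worth calling out in the single-state 8-bit kernel above: weight decay is applied in two different ways depending on the optimizer. Momentum/RMSProp/Adagrad fold it into the gradient (coupled decay), while LION decays the parameter directly by (1 - lr * weight_decay) (decoupled decay). A tiny sketch, with illustrative function names:

#include <cstdio>

void apply_weight_decay_coupled(float& g, float p, float weight_decay) {
    g += p * weight_decay;           // the gradient now carries the decay term
}

void apply_weight_decay_decoupled(float& p, float lr, float weight_decay) {
    p *= (1.0f - lr * weight_decay); // the parameter shrinks independently of the gradient
}

int main() {
    float p = 2.0f, g = 0.1f;
    apply_weight_decay_coupled(g, p, 0.01f);
    std::printf("coupled:   g = %f\n", g);
    apply_weight_decay_decoupled(p, 0.001f, 0.01f);
    std::printf("decoupled: p = %f\n", p);
}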
template <typename T, int BLOCK_SIZE, int NUM_VALS>
template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n) {
__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
{
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
int valid_items = 0; int valid_items = 0;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce; typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_VALS> BlockReduce;
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT; typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename BlockReduce::TempStorage reduce; __shared__ typename BlockReduce::TempStorage reduce;
...@@ -1455,64 +1364,42 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st ...@@ -1455,64 +1364,42 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
T vals[NUM_VALS]; T vals[NUM_VALS];
float local_sum = 0.0f; float local_sum = 0.0f;
for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x * BLOCK_SIZE) {
{
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
local_sum = 0.0f; local_sum = 0.0f;
__syncthreads(); __syncthreads();
LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f);
#pragma unroll NUM_VALS #pragma unroll NUM_VALS
for(int j = 0; j < NUM_VALS; j++) for (int j = 0; j < NUM_VALS; j++)
local_sum += ((float)vals[j])*((float)vals[j]); local_sum += ((float)vals[j]) * ((float)vals[j]);
local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); local_sum = BlockReduce(reduce).Sum(local_sum, valid_items);
if(threadIdx.x == 0) if (threadIdx.x == 0) {
{ if (step == 1) {
if(step == 1)
{
// initialize with the same norm for all positions // initialize with the same norm for all positions
//#pragma unroll 10 // #pragma unroll 10
for(int j = 0; j < 100; j++) for (int j = 0; j < 100; j++)
atomicAdd(&gnorm_vec[j], local_sum); atomicAdd(&gnorm_vec[j], local_sum);
} } else
else
atomicAdd(&gnorm_vec[step % 100], local_sum); atomicAdd(&gnorm_vec[step % 100], local_sum);
} }
} }
} }
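kPercentileClipping only maintains the history: it stores the squared gradient norm of each step in a 100-entry ring buffer (gnorm_vec), seeding all slots on step 1. How that history turns into a clipping factor is decided outside the kernel; the sketch below shows one plausible host-side use, clipping the current norm to a percentile of the history to derive gnorm_scale. The function is illustrative and not the library's actual API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

float percentile_clip_scale(std::vector<float> gnorm_sq_history, float current_norm_sq, double percentile) {
    std::sort(gnorm_sq_history.begin(), gnorm_sq_history.end());
    size_t idx = (size_t)(percentile * (gnorm_sq_history.size() - 1));
    float clip_norm = std::sqrt(gnorm_sq_history[idx]);
    float current_norm = std::sqrt(current_norm_sq);
    return (current_norm > clip_norm) ? clip_norm / current_norm : 1.0f;
}

int main() {
    std::vector<float> history(100, 1.0f); // pretend all past squared norms were 1.0
    float scale = percentile_clip_scale(history, 4.0f, 0.95);
    std::printf("gnorm_scale = %f\n", scale); // 0.5: current norm 2.0 clipped to 1.0
}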
#define LANES 2 #define LANES 2
#define QUAD 3 #define QUAD 3
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3) template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise(
kOptimizerStatic8bit2StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2,
T* p, const float beta3, const float alpha, const float eps, const int step, const float lr,
T* __restrict__ const g, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2,
unsigned char* state1, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n
unsigned char* state2,
const float beta1,
const float beta2,
const float beta3,
const float alpha,
const float eps,
const int step,
const float lr,
float* __restrict__ const quantiles1,
float* __restrict__ const quantiles2,
float* absmax1,
float* absmax2,
float weight_decay,
const float gnorm_scale,
const bool skip_zeros,
const int n
) { ) {
//const int n_full = n + (n%BLOCK_SIZE); // const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE; const int n_full = gridDim.x * BLOCK_SIZE;
const int base_idx = (blockIdx.x * BLOCK_SIZE); const int base_idx = (blockIdx.x * BLOCK_SIZE);
int valid_items = 0; int valid_items = 0;
...@@ -1523,8 +1410,8 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1523,8 +1410,8 @@ kOptimizerStatic8bit2StateBlockwise(
// 2-5% // 2-5%
const float correction1 = 1.0f - __powf(beta1, step); const float correction1 = 1.0f - __powf(beta1, step);
const float correction2 = sqrtf(1.0f -__powf(beta2, step)); const float correction2 = sqrtf(1.0f - __powf(beta2, step));
const float step_size = __fdividef(-lr*correction2,correction1); const float step_size = __fdividef(-lr * correction2, correction1);
const int lane_id = threadIdx.x % LANES; const int lane_id = threadIdx.x % LANES;
float new_local_abs_max1 = -FLT_MAX; float new_local_abs_max1 = -FLT_MAX;
float new_local_abs_max2 = -FLT_MAX; float new_local_abs_max2 = -FLT_MAX;
...@@ -1538,17 +1425,17 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1538,17 +1425,17 @@ kOptimizerStatic8bit2StateBlockwise(
T g_vals[N_PER_TH]; T g_vals[N_PER_TH];
T p_vals[N_PER_TH]; T p_vals[N_PER_TH];
typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT; typedef cub::BlockLoad<T, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar; typedef cub::BlockLoad<unsigned char, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; typedef cub::BlockStore<unsigned char, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; typedef cub::BlockStore<T, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257]; __shared__ float smem_quantiles1[LANES][257];
__shared__ float smem_quantiles2[LANES][257]; __shared__ float smem_quantiles2[LANES][257];
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1; typedef cub::BlockReduce<float, BLOCK_SIZE / N_PER_TH> BlockReduce1;
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2; typedef cub::BlockReduce<float, BLOCK_SIZE / N_PER_TH> BlockReduce2;
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce3; typedef cub::BlockReduce<float, BLOCK_SIZE / N_PER_TH> BlockReduce3;
__shared__ typename BlockReduce1::TempStorage reduce1; __shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ typename BlockReduce2::TempStorage reduce2; __shared__ typename BlockReduce2::TempStorage reduce2;
__shared__ typename BlockReduce2::TempStorage reduce3; __shared__ typename BlockReduce2::TempStorage reduce3;
...@@ -1562,30 +1449,27 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1562,30 +1449,27 @@ kOptimizerStatic8bit2StateBlockwise(
typename StoreChar::TempStorage storec; typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh; typename StoreT::TempStorage storeh;
} temp_storage; } temp_storage;
// init: 0.2 -> 0.23 // init: 0.2 -> 0.23
// 0.23 -> 0.23 // 0.23 -> 0.23
smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x];
# pragma unroll #pragma unroll
for(unsigned int j = 1; j < LANES; j++) for (unsigned int j = 1; j < LANES; j++) {
{
smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x];
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for(int k = 0; k < QUAD; k++) for (int k = 0; k < QUAD; k++) {
{ quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)];
quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; quadrants2[k] = smem_quantiles2[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)];
quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
} }
for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
// loads: 0.23 -> 0.85/1.44 // loads: 0.23 -> 0.85/1.44
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
__syncthreads(); __syncthreads();
...@@ -1605,31 +1489,27 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1605,31 +1489,27 @@ kOptimizerStatic8bit2StateBlockwise(
new_local_abs_max2 = -FLT_MAX; new_local_abs_max2 = -FLT_MAX;
new_local_abs_max3 = -FLT_MAX; new_local_abs_max3 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60 // update: 2.48/1.57 -> 2.51/1.60
# pragma unroll N_PER_TH #pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++) for (unsigned int j = 0; j < N_PER_TH; j++) {
{ if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) {
if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) s2_vals[j] = smem_quantiles2[lane_id][c2s[j]] * absmax2[i / BLOCK_SIZE];
{
s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
g_val = g_vals[j]; g_val = g_vals[j];
//float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); // float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps);
//g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; // g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val;
g_val *= gnorm_scale; g_val *= gnorm_scale;
s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val));
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val));
if (OPTIMIZER == ADEMAMIX) { if (OPTIMIZER == ADEMAMIX) {
// The absmax for the third state is appended to absmax1 // The absmax for the third state is appended to absmax1
s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE]; s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i) / BLOCK_SIZE];
s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val));
} }
} } else {
else
{
s1_vals[j] = 0.0f; s1_vals[j] = 0.0f;
s2_vals[j] = 0.0f; s2_vals[j] = 0.0f;
...@@ -1646,7 +1526,6 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1646,7 +1526,6 @@ kOptimizerStatic8bit2StateBlockwise(
} }
} }
// reduce: 2.51/1.60 -> 2.67/1.69 // reduce: 2.51/1.60 -> 2.67/1.69
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max()); new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());
...@@ -1655,8 +1534,7 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1655,8 +1534,7 @@ kOptimizerStatic8bit2StateBlockwise(
new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max()); new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max());
} }
if(threadIdx.x == 0) if (threadIdx.x == 0) {
{
smem_exchange1[0] = new_local_abs_max1; smem_exchange1[0] = new_local_abs_max1;
smem_exchange2[0] = new_local_abs_max2; smem_exchange2[0] = new_local_abs_max2;
...@@ -1667,17 +1545,14 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1667,17 +1545,14 @@ kOptimizerStatic8bit2StateBlockwise(
__syncthreads(); __syncthreads();
if(threadIdx.x == 0) if (threadIdx.x == 0) {
{ absmax1[i / BLOCK_SIZE] = new_local_abs_max1;
absmax1[i/BLOCK_SIZE] = new_local_abs_max1; absmax2[i / BLOCK_SIZE] = new_local_abs_max2;
absmax2[i/BLOCK_SIZE] = new_local_abs_max2;
if (OPTIMIZER == ADEMAMIX) { if (OPTIMIZER == ADEMAMIX) {
absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3; absmax1[(n + i) / BLOCK_SIZE] = new_local_abs_max3;
}
} }
else } else {
{
new_local_abs_max1 = smem_exchange1[0]; new_local_abs_max1 = smem_exchange1[0];
new_local_abs_max2 = smem_exchange2[0]; new_local_abs_max2 = smem_exchange2[0];
...@@ -1688,25 +1563,23 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1688,25 +1563,23 @@ kOptimizerStatic8bit2StateBlockwise(
__syncthreads(); __syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
// reduce: 2.67/1.69 -> 2.67/1.70 // reduce: 2.67/1.69 -> 2.67/1.70
# pragma unroll N_PER_TH #pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++) for (unsigned int j = 0; j < N_PER_TH; j++) {
{ // if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
//if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) {
if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
{
if (OPTIMIZER == ADEMAMIX) { if (OPTIMIZER == ADEMAMIX) {
p_vals[j] = T((float)p_vals[j] - lr * ( p_vals[j] =
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( T((float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) /
(sqrtf(s2_vals[j]) / correction2) + eps ((sqrtf(s2_vals[j]) / correction2) + eps)));
)
));
} else { } else {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); p_vals[j] =
(T)(((float)p_vals[j]) +
((step_size * (__fdividef(s1_vals[j], (sqrtf(s2_vals[j]) + (correction2 * eps)))))));
} }
if(weight_decay > 0.0f) if (weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
} }
} }
...@@ -1714,25 +1587,24 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1714,25 +1587,24 @@ kOptimizerStatic8bit2StateBlockwise(
__syncthreads(); __syncthreads();
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
// quantization: 2.67/1.70 -> 3.4/3.3 // quantization: 2.67/1.70 -> 3.4/3.3
# pragma unroll N_PER_TH #pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++) for (unsigned int j = 0; j < N_PER_TH; j++) {
{ c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1));
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j], new_local_abs_max2));
c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2));
// make sure state1 term has still the same sign after quantization // make sure state1 term has still the same sign after quantization
// (not needed for state2 term which has only positive values) // (not needed for state2 term which has only positive values)
if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) {
{ if (s1_vals[j] > 0.0f)
if(s1_vals[j] > 0.0f)
c1s[j] += 1; c1s[j] += 1;
else else
c1s[j] -= 1; c1s[j] -= 1;
} }
if (OPTIMIZER == ADEMAMIX) { if (OPTIMIZER == ADEMAMIX) {
c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j],new_local_abs_max3)); c3s[j] =
quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j], new_local_abs_max3));
if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) {
c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1; c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1;
...@@ -1752,22 +1624,17 @@ kOptimizerStatic8bit2StateBlockwise( ...@@ -1752,22 +1624,17 @@ kOptimizerStatic8bit2StateBlockwise(
} }
} }
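The blockwise kernels differ from the per-tensor 8-bit kernels in that every BLOCK_SIZE chunk of a state carries its own absmax (absmax1/absmax2), so one outlier only degrades the precision of its own block. The host-side sketch below shows that round trip for a single block — normalize by the block absmax, pick the nearest codebook entry, and dequantize as codebook[c] * absmax. The struct and function names are illustrative and deliberately distinct from the repo's quantize_block.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct QuantizedBlock { std::vector<unsigned char> codes; float absmax; };

QuantizedBlock quantize_state_block(const std::vector<float>& vals, const float code[256]) {
    QuantizedBlock q{std::vector<unsigned char>(vals.size()), 0.0f};
    for (float v : vals) q.absmax = std::fmax(q.absmax, std::fabs(v));
    for (size_t j = 0; j < vals.size(); ++j) {
        float normed = (q.absmax > 0.0f) ? vals[j] / q.absmax : 0.0f;
        unsigned char best = 0;
        float best_dist = 1e30f;
        for (int i = 0; i < 256; ++i) { // nearest codebook entry
            float d = std::fabs(code[i] - normed);
            if (d < best_dist) { best_dist = d; best = (unsigned char)i; }
        }
        q.codes[j] = best;
    }
    return q;
}

float dequantize(const QuantizedBlock& q, const float code[256], size_t j) { return code[q.codes[j]] * q.absmax; }

int main() {
    float code[256];
    for (int i = 0; i < 256; ++i) code[i] = (i - 127.5f) / 127.5f; // toy codebook
    std::vector<float> vals = {0.25f, -0.5f, 0.125f, 1.0f};
    QuantizedBlock q = quantize_state_block(vals, code);
    std::printf("absmax=%f, roundtrip of vals[1]=%f\n", q.absmax, dequantize(q, code, 1));
}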
#define LANES 2 #define LANES 2
#define QUAD 3 #define QUAD 3
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3) template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise(
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps,
const float beta1, const float beta2, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay,
const float eps, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n
float* __restrict__ const quantiles1, ) {
float* absmax1,
float weight_decay, // const int n_full = n + (n%BLOCK_SIZE);
const float gnorm_scale, const bool skip_zeros, const int n)
{
//const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE; const int n_full = gridDim.x * BLOCK_SIZE;
const int base_idx = (blockIdx.x * BLOCK_SIZE); const int base_idx = (blockIdx.x * BLOCK_SIZE);
int valid_items = 0; int valid_items = 0;
...@@ -1782,14 +1649,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1782,14 +1649,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
T g_vals[N_PER_TH]; T g_vals[N_PER_TH];
T p_vals[N_PER_TH]; T p_vals[N_PER_TH];
typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT; typedef cub::BlockLoad<T, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar; typedef cub::BlockLoad<unsigned char, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; typedef cub::BlockStore<unsigned char, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; typedef cub::BlockStore<T, BLOCK_SIZE / N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257]; __shared__ float smem_quantiles1[LANES][257];
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1; typedef cub::BlockReduce<float, BLOCK_SIZE / N_PER_TH> BlockReduce1;
__shared__ typename BlockReduce1::TempStorage reduce1; __shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ float smem_exchange1[1]; __shared__ float smem_exchange1[1];
...@@ -1799,22 +1666,22 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1799,22 +1666,22 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
typename StoreChar::TempStorage storec; typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh; typename StoreT::TempStorage storeh;
} temp_storage; } temp_storage;
// init: 0.2 -> 0.23 // init: 0.2 -> 0.23
// 0.23 -> 0.23 // 0.23 -> 0.23
smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
# pragma unroll #pragma unroll
for(unsigned int j = 1; j < LANES; j++) for (unsigned int j = 1; j < LANES; j++)
smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for(int k = 0; k < QUAD; k++) for (int k = 0; k < QUAD; k++)
quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
{
// loads: 0.23 -> 0.85/1.44 // loads: 0.23 -> 0.85/1.44
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
__syncthreads(); __syncthreads();
...@@ -1826,47 +1693,45 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1826,47 +1693,45 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
new_local_abs_max1 = -FLT_MAX; new_local_abs_max1 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60 // update: 2.48/1.57 -> 2.51/1.60
# pragma unroll N_PER_TH #pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++) for (unsigned int j = 0; j < N_PER_TH; j++) {
{
g_val = float(g_vals[j]); g_val = float(g_vals[j]);
g_val *= gnorm_scale; g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) {
{ if (weight_decay > 0.0f) {
if(weight_decay > 0.0f) { switch (OPTIMIZER) {
switch(OPTIMIZER) {
case MOMENTUM: case MOMENTUM:
case ADAGRAD: case ADAGRAD:
case RMSPROP: case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay; g_val += ((float)p_vals[j]) * weight_decay;
break; break;
case LION: case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay);
break; break;
} }
} }
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE];
switch(OPTIMIZER) switch (OPTIMIZER) {
{
case MOMENTUM: case MOMENTUM:
if(step == 1) if (step == 1)
s1_vals[j] = g_val; s1_vals[j] = g_val;
else else
s1_vals[j] = (s1_vals[j]*beta1) + g_val; s1_vals[j] = (s1_vals[j] * beta1) + g_val;
break; break;
case LION: case LION:
// here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update,
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); // before the momentum is updated by beta2
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); g_vals[j] = lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * g_val));
s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val);
break; break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val));
break; break;
case ADAGRAD: case ADAGRAD:
s1_vals[j] = s1_vals[j] + (g_val*g_val); s1_vals[j] = s1_vals[j] + (g_val * g_val);
break; break;
} }
} }
...@@ -1874,41 +1739,37 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1874,41 +1739,37 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
} }
// reduce: 2.51/1.60 -> 2.67/1.69 // reduce: 2.51/1.60 -> 2.67/1.69
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
if(threadIdx.x == 0) if (threadIdx.x == 0)
smem_exchange1[0] = new_local_abs_max1; smem_exchange1[0] = new_local_abs_max1;
__syncthreads(); __syncthreads();
if(threadIdx.x == 0) if (threadIdx.x == 0)
absmax1[i/BLOCK_SIZE] = new_local_abs_max1; absmax1[i / BLOCK_SIZE] = new_local_abs_max1;
else else
new_local_abs_max1 = smem_exchange1[0]; new_local_abs_max1 = smem_exchange1[0];
// reduce: 2.67/1.69 -> 2.67/1.70 // reduce: 2.67/1.69 -> 2.67/1.70
# pragma unroll N_PER_TH #pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++) for (unsigned int j = 0; j < N_PER_TH; j++) {
{ if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) {
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) switch (OPTIMIZER) {
{
switch(OPTIMIZER)
{
case MOMENTUM: case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); p_vals[j] = ((float)p_vals[j]) - lr * (s1_vals[j]);
break; break;
case LION: case LION:
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
break; break;
case RMSPROP: case RMSPROP:
g_val = g_vals[j]; g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps));
break; break;
case ADAGRAD: case ADAGRAD:
g_val = g_vals[j]; g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps));
break; break;
} }
} }
...@@ -1918,17 +1779,15 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1918,17 +1779,15 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
__syncthreads(); __syncthreads();
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
// quantization: 2.67/1.70 -> 3.4/3.3 // quantization: 2.67/1.70 -> 3.4/3.3

# pragma unroll N_PER_TH #pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++) for (unsigned int j = 0; j < N_PER_TH; j++) {
{ c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1));
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
// make sure state1 term has still the same sign after quantization // make sure state1 term has still the same sign after quantization
// (not needed for state2 term which has only positive values) // (not needed for state2 term which has only positive values)
if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) {
{ if (s1_vals[j] > 0.0f)
if(s1_vals[j] > 0.0f)
c1s[j] += 1; c1s[j] += 1;
else else
c1s[j] -= 1; c1s[j] -= 1;
...@@ -1945,9 +1804,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1945,9 +1804,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
// Outputs: // Outputs:
// rowStats [rows] // rowStats [rows]
// out [rows, cols] // out [rows, cols]
template<typename T, int THREADS, int SPARSE_DECOMP> template <typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
// Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
...@@ -2009,9 +1868,9 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat ...@@ -2009,9 +1868,9 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
} }
} }
template<typename T, int THREADS, int SPARSE_DECOMP> template <typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) { void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols) {
using BlockReduceT = cub::BlockReduce<float, THREADS>; using BlockReduceT = cub::BlockReduce<float, THREADS>;
// One block per row. // One block per row.
...@@ -2049,25 +1908,24 @@ __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshol ...@@ -2049,25 +1908,24 @@ __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshol
} }
} }
template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); template __global__ void
template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); kgetRowStats<half, 1024, 0>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template __global__ void
template __global__ void kInt8VectorQuant<half, 1024, 0>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); kgetRowStats<half, 1024, 1>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 1>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 0>(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols
);
template __global__ void kInt8VectorQuant<half, 1024, 1>(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols
);
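// ---------------------------------------------------------------------------
// [Editor's sketch, not part of this commit] Row-wise absmax int8 quantization
// as these kernels implement it in spirit: kgetRowStats finds the per-row
// absolute maximum, kInt8VectorQuant scales each row into [-127, 127]. The
// kernel bodies are collapsed in this diff, so this host reference only
// illustrates the general scheme (threshold/sparse-decomposition handling omitted).
#include <math.h>

static void int8_rowwise_quant_ref(const float* A, signed char* out, float* rowStats, int rows, int cols) {
    for (int r = 0; r < rows; r++) {
        float absmax = 0.0f;
        for (int c = 0; c < cols; c++) // row statistic: max |A[r, :]|
            absmax = fmaxf(absmax, fabsf(A[r * cols + c]));
        rowStats[r] = absmax;
        const float scale = absmax > 0.0f ? 127.0f / absmax : 0.0f;
        for (int c = 0; c < cols; c++) // quantize: round-to-nearest of A * 127 / absmax
            out[r * cols + c] = (signed char)nearbyintf(A[r * cols + c] * scale);
    }
}
// ---------------------------------------------------------------------------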
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) #define MM_DEQUANT_CONST 6.200012e-05f // 1.0f/(127.0f*127.0f)
template <int ITEMS_PER_THREAD, int THREADS> template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16( __global__ void kdequant_mm_int32_fp16(
int* __restrict__ const A, int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
float *__restrict__ const rowStats, half* __restrict__ const bias, const int numRows, const int numCols, const int n
float *__restrict__ const colStats,
half *out,
half *__restrict__ const bias,
const int numRows,
const int numCols,
const int n
) { ) {
const int n_out = numRows * numCols; const int n_out = numRows * numCols;
...@@ -2086,7 +1944,7 @@ __global__ void kdequant_mm_int32_fp16( ...@@ -2086,7 +1944,7 @@ __global__ void kdequant_mm_int32_fp16(
int row_idx, col_idx; int row_idx, col_idx;
#pragma unroll ITEMS_PER_THREAD #pragma unroll ITEMS_PER_THREAD
for (int j = 0; j < ITEMS_PER_THREAD; ++j) { for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
row_idx = (block_offset + thread_offset + j) / numCols; row_idx = (block_offset + thread_offset + j) / numCols;
...@@ -2098,19 +1956,18 @@ __global__ void kdequant_mm_int32_fp16( ...@@ -2098,19 +1956,18 @@ __global__ void kdequant_mm_int32_fp16(
} }
// Each block loads THREADS * ITEMS_PER_THREAD values from A // Each block loads THREADS * ITEMS_PER_THREAD values from A
int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out int valid_items =
? THREADS * ITEMS_PER_THREAD block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset;
: n_out - block_offset;
LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);
#pragma unroll ITEMS_PER_THREAD #pragma unroll ITEMS_PER_THREAD
for (int j = 0; j < ITEMS_PER_THREAD; ++j) { for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
local_output[j] = __float2half( local_output[j] = __float2half(
fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j])
); );
} }
#pragma unroll ITEMS_PER_THREAD #pragma unroll ITEMS_PER_THREAD
for (int j = 0; j < ITEMS_PER_THREAD; j++) { for (int j = 0; j < ITEMS_PER_THREAD; j++) {
int outIdx = block_offset + thread_offset + j; int outIdx = block_offset + thread_offset + j;
if (outIdx < n_out) { if (outIdx < n_out) {
...@@ -2119,12 +1976,15 @@ __global__ void kdequant_mm_int32_fp16( ...@@ -2119,12 +1976,15 @@ __global__ void kdequant_mm_int32_fp16(
} }
} }
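// ---------------------------------------------------------------------------
// [Editor's sketch, not part of this commit] The per-element arithmetic of
// kdequant_mm_int32_fp16 as a scalar host reference. MM_DEQUANT_CONST above is
// 1/(127*127) = 1/16129 ≈ 6.200012e-05: both int8 matmul inputs were scaled
// into [-127, 127] by their row/column absmax, so the int32 accumulator carries
// a factor of 127*127 that has to be divided back out before adding the bias.
#include <math.h>

static inline float dequant_mm_elem_ref(int acc_i32, float row_stat, float col_stat, float bias) {
    const float mm_dequant_const = 1.0f / (127.0f * 127.0f); // == MM_DEQUANT_CONST
    return fmaf((float)acc_i32 * row_stat * col_stat, mm_dequant_const, bias);
}
// ---------------------------------------------------------------------------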
#define DENORM 1.0f/127.0f #define DENORM 1.0f / 127.0f
#define MAX_SPARSE_COUNT 32 #define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256 #define SMEM_SIZE 8 * 256
template <typename T, int SPMM_ITEMS, int BITS> template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) __global__ void kspmm_coo_very_sparse_naive(
{ int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
) {
    // 0. load balancing: We process rows with most columns first (count_vec) and we process one row per block    // 0. load balancing: We process rows with most columns first (count_vec) and we process one row per block
    // If a block finishes, the next one is scheduled. Since the last blocks likely have fewer    // If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
...@@ -2139,12 +1999,12 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o ...@@ -2139,12 +1999,12 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
const int count = max_count[blockIdx.x]; const int count = max_count[blockIdx.x];
const int local_max_idx = max_idx[blockIdx.x]; const int local_max_idx = max_idx[blockIdx.x];
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx - 1];
const int local_row_idx = rowidx[offset]; const int local_row_idx = rowidx[offset];
const int warp_id = threadIdx.x / 32; const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32; const int warp_idx = threadIdx.x % 32;
const int warp_offset = (warp_id*32)*SPMM_ITEMS; const int warp_offset = (warp_id * 32) * SPMM_ITEMS;
const int num_items = BITS == 8 ? 8 : 8; const int num_items = BITS == 8 ? 8 : 8;
int idx_col_B = warp_offset; int idx_col_B = warp_offset;
int local_idx_col_B_offset = 0; int local_idx_col_B_offset = 0;
...@@ -2157,10 +2017,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o ...@@ -2157,10 +2017,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
// 128 byte loads per warp == 4 bytes per thread // 128 byte loads per warp == 4 bytes per thread
// 2. Load A into registers // 2. Load A into registers
for(int j = 0; j < MAX_SPARSE_COUNT; j++) for (int j = 0; j < MAX_SPARSE_COUNT; j++) {
{ local_valA[j] = j < count ? values[offset + j] : __float2half(0.0f);
local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); local_colidxA[j] = j < count ? colidx[offset + j] : 0;
local_colidxA[j] = j < count ? colidx[offset+j] : 0;
} }
    // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192    // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
...@@ -2169,124 +2028,119 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o ...@@ -2169,124 +2028,119 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
// added 3 bytes = 6 values between warps should reduce bank conflicts // added 3 bytes = 6 values between warps should reduce bank conflicts
__shared__ half smem_dequant_stats[SMEM_SIZE]; __shared__ half smem_dequant_stats[SMEM_SIZE];
while (idx_col_B < colsB) {
while(idx_col_B < colsB) if (dequant_stats != NULL) {
{ for (int i = threadIdx.x; i < SMEM_SIZE; i += blockDim.x)
if ((idx_col_B + i - local_idx_col_B_offset) < colsB)
if(dequant_stats != NULL) smem_dequant_stats[i] = dequant_stats[idx_col_B + i - local_idx_col_B_offset];
{
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset];
__syncthreads(); __syncthreads();
} }
#pragma unroll SPMM_ITEMS #pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j++) for (int j = 0; j < SPMM_ITEMS; j++)
local_valC[j] = 0.0f; local_valC[j] = 0.0f;
#pragma unroll #pragma unroll
for(int i = 0; i < count; i++) for (int i = 0; i < count; i++) {
{
// 3. each warp loads all required rows of B but each warp is offset by k // 3. each warp loads all required rows of B but each warp is offset by k
int row_offset = colsB*local_colidxA[i]; int row_offset = colsB * local_colidxA[i];
#pragma unroll SPMM_ITEMS #pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j+=num_items) for (int j = 0; j < SPMM_ITEMS; j += num_items) {
{
            // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes is reached            // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes is reached
int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; int idx = idx_col_B + (warp_idx * SPMM_ITEMS) + j;
if(idx >= colsB){ break; } if (idx >= colsB) {
if((idx+num_items < colsB)) break;
{
if(BITS == 8)
reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
else
reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
} }
if ((idx + num_items < colsB)) {
if (BITS == 8)
reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] =
reinterpret_cast<float2*>(B)[(row_offset + idx) / num_items];
else else
{ reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] =
#pragma unroll num_items reinterpret_cast<float4*>(B)[(row_offset + idx) / num_items];
for(int k = 0; k < num_items; k++) } else {
if(idx+k < colsB) #pragma unroll num_items
local_valsB[k] = B[row_offset+idx+k]; for (int k = 0; k < num_items; k++)
if (idx + k < colsB)
local_valsB[k] = B[row_offset + idx + k];
else else
local_valsB[k] = 0.0f; local_valsB[k] = 0.0f;
} }
#pragma unroll num_items #pragma unroll num_items
for(int k = 0; k < num_items; k++) for (int k = 0; k < num_items; k++) {
{ if (BITS == 8 && dequant_stats != NULL)
if(BITS == 8 && dequant_stats != NULL)
// we do texture cache reads (__ldg) on dequant_stats which should be super fast // we do texture cache reads (__ldg) on dequant_stats which should be super fast
{ {
float valB = local_valsB[k]; float valB = local_valsB[k];
float valA = local_valA[i]; float valA = local_valA[i];
if(valB != 0.0 && valA != 0.0) if (valB != 0.0 && valA != 0.0)
local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; local_valC[j + k] =
} (float)local_valC[j + k] +
else ((float)smem_dequant_stats[idx + k - local_idx_col_B_offset]) * DENORM * valB * valA;
local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; } else
local_valC[j + k] = (float)local_valC[j + k] + (float)local_valsB[k] * (float)local_valA[i];
} }
} }
} }
int idx_row_C = (colsB*local_row_idx); int idx_row_C = (colsB * local_row_idx);
#pragma unroll SPMM_ITEMS #pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j+=num_items) for (int j = 0; j < SPMM_ITEMS; j += num_items) {
{ // int idx_col_C = idx_col_B + (32*j) + warp_idx;
//int idx_col_C = idx_col_B + (32*j) + warp_idx; int idx_col_C = idx_col_B + warp_idx * SPMM_ITEMS + j;
int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
int idx_val = idx_col_C + idx_row_C; int idx_val = idx_col_C + idx_row_C;
if(idx_col_C +num_items < colsB) if (idx_col_C + num_items < colsB) {
{
// load outputs to do inplace addition // load outputs to do inplace addition
reinterpret_cast<float4(&)[num_items/4]>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items]; reinterpret_cast<float4(&)[num_items / 4]>(local_valOut)[0] =
reinterpret_cast<float4*>(out)[idx_val / num_items];
#pragma unroll num_items #pragma unroll num_items
for(int k = 0; k < num_items; k++) for (int k = 0; k < num_items; k++)
local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; local_valC[(j / num_items) + k] = (float)local_valC[(j / num_items) + k] + (float)local_valOut[k];
reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items]; reinterpret_cast<float4*>(out)[idx_val / num_items] =
} reinterpret_cast<float4(&)[num_items]>(local_valC)[j / num_items];
else } else {
{ #pragma unroll num_items
#pragma unroll num_items for (int k = 0; k < num_items; k++)
for(int k = 0; k < num_items; k++) if (idx_col_C + k < colsB)
if(idx_col_C + k < colsB) out[idx_val + k] = (float)out[idx_val + k] + (float)local_valC[j + k];
out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
} }
} }
idx_col_B += blockDim.x*SPMM_ITEMS; idx_col_B += blockDim.x * SPMM_ITEMS;
local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; local_idx_col_B_offset += blockDim.x * SPMM_ITEMS;
} }
} }
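// ---------------------------------------------------------------------------
// [Editor's sketch, not part of this commit] What one block of
// kspmm_coo_very_sparse_naive computes, written as a dense host reference for
// a single sparse row: out[row, :] += sum_i values[i] * B[colidx[i], :].
// `values`/`colidx` hold the `count` nonzeros of that row (COO layout); B is
// [rowsB, colsB] row-major. The int8 path that rescales B with dequant_stats
// and DENORM = 1/127 is omitted here.
static void spmm_coo_row_ref(const float* values, const int* colidx, int count,
                             const float* B, int colsB, float* out_row) {
    for (int i = 0; i < count; i++) {
        const float a = values[i];
        const float* b_row = B + (long long)colidx[i] * colsB;
        for (int c = 0; c < colsB; c++)
            out_row[c] += a * b_row[c]; // in-place accumulation, as the kernel does
    }
}
// ---------------------------------------------------------------------------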
#define WARPS 3 #define WARPS 3
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{ template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc) {
#if __CUDA_ARCH__ >= 750 #if __CUDA_ARCH__ >= 750
using namespace nvcuda; using namespace nvcuda;
int col_offset = blockIdx.x *32; int col_offset = blockIdx.x * 32;
const int warp_id = threadIdx.x / 32; const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16; const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16; const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2; const int batch_size_warps = (WARPS - 1) * 2;
const int val_per_iter = blockDim.x-32; const int val_per_iter = blockDim.x - 32;
T local_A[4]; T local_A[4];
T local_B[128]; T local_B[128];
const int a_tile_offset = 16; const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16); const int b_tile_offset = (16 * 32 + 16);
__shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; __shared__ T smem_A[8 * 16 + (2 * 16 * (batch_size_warps - 1))];
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))];
//__shared__ T smem_C[8*32]; //__shared__ T smem_C[8*32];
wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag; wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
...@@ -2298,194 +2152,177 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, ...@@ -2298,194 +2152,177 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
int idx = 0 + threadIdx.x; int idx = 0 + threadIdx.x;
int loaded_values = 0; int loaded_values = 0;
// prefetch // prefetch
if(idx < K && warp_id < (WARPS-1)) if (idx < K && warp_id < (WARPS - 1)) {
{ if (loaded_values == 0) {
if(loaded_values == 0)
{
local_A[0] = A[idx]; local_A[0] = A[idx];
local_A[1] = A[idx+(1*val_per_iter)]; local_A[1] = A[idx + (1 * val_per_iter)];
local_A[2] = A[idx+(2*val_per_iter)]; local_A[2] = A[idx + (2 * val_per_iter)];
local_A[3] = A[idx+(3*val_per_iter)]; local_A[3] = A[idx + (3 * val_per_iter)];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++) {
{ local_B[col] = B[(col_offset + col) * ldb + idx];
local_B[col] = B[(col_offset+col)*ldb+idx]; local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)];
local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)];
local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)];
local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
} }
loaded_values = 3; loaded_values = 3;
} } else {
else
{
if(loaded_values == 3) if (loaded_values == 3) {
{
local_A[0] = local_A[1]; local_A[0] = local_A[1];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = local_B[col+(32)]; local_B[col] = local_B[col + (32)];
} } else if (loaded_values == 2) {
else if(loaded_values == 2)
{
local_A[0] = local_A[2]; local_A[0] = local_A[2];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = local_B[col+(64)]; local_B[col] = local_B[col + (64)];
} } else {
else
{
local_A[0] = local_A[3]; local_A[0] = local_A[3];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = local_B[col+(96)]; local_B[col] = local_B[col + (96)];
} }
loaded_values--; loaded_values--;
} }
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
} local_B[col];
else if(warp_id < (WARPS-1)) } else if (warp_id < (WARPS - 1)) {
{
local_A[0] = T(0.0); local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = 0.0f; local_B[col] = 0.0f;
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
} }
ticktock = ticktock == 0 ? 1 : 0; ticktock = ticktock == 0 ? 1 : 0;
//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) {
{
idx = base_idx + threadIdx.x; idx = base_idx + threadIdx.x;
__syncthreads(); __syncthreads();
if(idx < K && warp_id < (WARPS-1)) if (idx < K && warp_id < (WARPS - 1)) {
{ // local_A[0] = A[idx];
//local_A[0] = A[idx];
//#pragma unroll 32 // #pragma unroll 32
//for(int col = 0; col < 32; col++) // for(int col = 0; col < 32; col++)
// local_B[col] = B[(col_offset+col)*ldb+idx]; // local_B[col] = B[(col_offset+col)*ldb+idx];
if(loaded_values == 0) if (loaded_values == 0) {
{
local_A[0] = A[idx]; local_A[0] = A[idx];
local_A[1] = A[idx+(1*val_per_iter)]; local_A[1] = A[idx + (1 * val_per_iter)];
local_A[2] = A[idx+(2*val_per_iter)]; local_A[2] = A[idx + (2 * val_per_iter)];
local_A[3] = A[idx+(3*val_per_iter)]; local_A[3] = A[idx + (3 * val_per_iter)];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++) {
{ local_B[col] = B[(col_offset + col) * ldb + idx];
local_B[col] = B[(col_offset+col)*ldb+idx]; local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)];
local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)];
local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)];
local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
} }
loaded_values = 3; loaded_values = 3;
} } else {
else
{
if(loaded_values == 3) if (loaded_values == 3) {
{
local_A[0] = local_A[1]; local_A[0] = local_A[1];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = local_B[col+(32)]; local_B[col] = local_B[col + (32)];
} } else if (loaded_values == 2) {
else if(loaded_values == 2)
{
local_A[0] = local_A[2]; local_A[0] = local_A[2];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = local_B[col+(64)]; local_B[col] = local_B[col + (64)];
} } else {
else
{
local_A[0] = local_A[3]; local_A[0] = local_A[3];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = local_B[col+(96)]; local_B[col] = local_B[col + (96)];
} }
loaded_values--; loaded_values--;
} }
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
} local_B[col];
else if(warp_id < (WARPS-1)) } else if (warp_id < (WARPS - 1)) {
{
local_A[0] = T(0.0); local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = 0.0f; local_B[col] = 0.0f;
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
} }
ticktock = ticktock == 0 ? 1 : 0; ticktock = ticktock == 0 ? 1 : 0;
if(warp_id == (WARPS-1)) if (warp_id == (WARPS - 1))
for(int k = 0; k < batch_size_warps; k++) for (int k = 0; k < batch_size_warps; k++) {
{ wmma::load_matrix_sync(
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu ); // 111 mu
wmma::load_matrix_sync(
b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16
); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
} }
} }
__syncthreads(); __syncthreads();
if(warp_id != (WARPS-1)){ return; } if (warp_id != (WARPS - 1)) {
return;
}
// only warp_id == (WARPS-1) from here // only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32; int warp_lane = threadIdx.x % 32;
ticktock = ticktock == 0 ? 1 : 0; ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++) for (int k = 0; k < batch_size_warps; k++) {
{ wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
} }
// 129 mu // 129 mu
if(warp_id == (WARPS-1)) if (warp_id == (WARPS - 1))
wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
if(col_offset + warp_lane < M) if (col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane]; out[col_offset + warp_lane] = smem_A[warp_lane];
#endif #endif
} }
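// ---------------------------------------------------------------------------
// [Editor's sketch, not part of this commit] The "ticktock" double buffering
// gemm_device uses, reduced to plain host code: while one buffer is being
// consumed, the other is refilled, and the buffer index flips once per K-slice.
// In the kernel the filling is done by the loader warps and the consuming by
// the last warp's WMMA calls; here both happen sequentially just to show the
// control flow. Buffer sizes and names are illustrative only.
#include <stddef.h>

static float ticktock_sum(const float* data, size_t n, size_t slice) {
    float buf[2][64];
    size_t len[2] = {0, 0};
    if (slice == 0) return 0.0f; // nothing to do
    if (slice > 64) slice = 64;  // keep within the fixed-size demo buffers
    int tt = 0;
    // prefetch the first slice into buf[0]
    len[tt] = n < slice ? n : slice;
    for (size_t i = 0; i < len[tt]; i++) buf[tt][i] = data[i];
    tt ^= 1;
    float acc = 0.0f;
    for (size_t base = slice; base < n; base += slice) {
        len[tt] = (n - base) < slice ? (n - base) : slice; // fill the idle buffer
        for (size_t i = 0; i < len[tt]; i++) buf[tt][i] = data[base + i];
        tt ^= 1;                                                // swap roles
        for (size_t i = 0; i < len[tt]; i++) acc += buf[tt][i]; // consume the previously filled one
    }
    tt ^= 1;
    for (size_t i = 0; i < len[tt]; i++) acc += buf[tt][i]; // drain the last filled slice
    return acc;
}
// ---------------------------------------------------------------------------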
template <typename T> __device__ void printnonzero(T* A, int num_values, const char* strval) {
template <typename T> __device__ void printnonzero(T *A, int num_values, const char * strval) for (int i = 0; i < num_values; i++)
{ if ((float)A[i] != 0.0)
for(int i = 0; i < num_values; i++)
if((float)A[i] != 0.0)
printf("%s %i %f\n", strval, i, (float)A[i]); printf("%s %i %f\n", strval, i, (float)A[i]);
} }
template <typename T, int THREADS>
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) __global__ void kgemm_4bit_inference(
{ int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
) {
//// element-wise kernel //// element-wise kernel
//// 1. Load batch x k into registers //// 1. Load batch x k into registers
...@@ -2507,17 +2344,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i ...@@ -2507,17 +2344,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
//// 9. write outputs to matmul output matrix //// 9. write outputs to matmul output matrix
#if __CUDA_ARCH__ >= 750 #if __CUDA_ARCH__ >= 750
using namespace nvcuda; using namespace nvcuda;
int col_offset = blockIdx.x *32; int col_offset = blockIdx.x * 32;
const int warp_id = threadIdx.x / 32; const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32; const int warp_idx = threadIdx.x % 32;
const int half_warp_id = threadIdx.x / 16; const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16; const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2; const int batch_size_warps = (WARPS - 1) * 2;
T quant_map[16]; T quant_map[16];
#pragma unroll 16 #pragma unroll 16
for(int i = 0; i < 16; i++) for (int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i]; quant_map[i] = nf4_data[i];
//__shared__ T quant_map[16*160]; //__shared__ T quant_map[16*160];
...@@ -2525,20 +2362,19 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i ...@@ -2525,20 +2362,19 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
T local_B[64]; T local_B[64];
unsigned char local_B_4bit[32]; unsigned char local_B_4bit[32];
const int a_tile_offset = 16; const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16); const int b_tile_offset = (16 * 32 + 16);
__shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; __shared__ T smem_A[8 * 16 + (16 * (batch_size_warps - 1))];
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))];
__shared__ T smem_C[8*32]; __shared__ T smem_C[8 * 32];
wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag; wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag; wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag; wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f); wmma::fill_fragment(c_frag, 0.0f);
for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) for (int i = threadIdx.x; i < (8 * 32); i += blockDim.x)
smem_C[i] = 0.0f; smem_C[i] = 0.0f;
__syncthreads(); __syncthreads();
...@@ -2547,305 +2383,287 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i ...@@ -2547,305 +2383,287 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
int idx = 0 + threadIdx.x; int idx = 0 + threadIdx.x;
int loaded_values = 0; int loaded_values = 0;
// prefetch // prefetch
if(idx < K && warp_id < (WARPS-1)) if (idx < K && warp_id < (WARPS - 1)) {
{ if (loaded_values == 0) {
if(loaded_values == 0)
{
local_A[0] = A[idx]; local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32]; local_A[1] = A[idx + blockDim.x - 32];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; local_B_4bit[col] = B[(col_offset + col) * ldb + idx];
loaded_values = 1; loaded_values = 1;
} } else {
else
{
local_A[0] = local_A[1]; local_A[0] = local_A[1];
loaded_values--; loaded_values--;
#pragma unroll 64 #pragma unroll 64
for(int col = 0; col < 64; col+=2) for (int col = 0; col < 64; col += 2) {
{ // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); // local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
//local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); // local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
//local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); // local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
//local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); // local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
//local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
//local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); // local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); // local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); local_B[col] = quant_map[160 * (local_B_4bit[col / 2] >> 4) + warp_idx] * T(17.0);
local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); local_B[col + 1] = quant_map[160 * (local_B_4bit[col / 2] & 0x0F) + warp_idx] * T(17.0);
} }
} }
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
} local_B[col];
else if(warp_id < (WARPS-1)) } else if (warp_id < (WARPS - 1)) {
{
local_A[0] = T(0.0); local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = 0.0f; local_B[col] = 0.0f;
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
} }
ticktock = ticktock == 0 ? 1 : 0; ticktock = ticktock == 0 ? 1 : 0;
//if(threadIdx.x == 0) // if(threadIdx.x == 0)
//printf("aa %i %i\n", idx, loaded_values); // printf("aa %i %i\n", idx, loaded_values);
//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) {
{
idx = base_idx + threadIdx.x; idx = base_idx + threadIdx.x;
//if(threadIdx.x == 0) // if(threadIdx.x == 0)
//printf("%i %i\n", idx, loaded_values); // printf("%i %i\n", idx, loaded_values);
//__syncthreads(); //__syncthreads();
if(idx < K && warp_id < (WARPS-1)) if (idx < K && warp_id < (WARPS - 1)) {
{ if (loaded_values == 0) {
if(loaded_values == 0)
{
local_A[0] = A[idx]; local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32]; local_A[1] = A[idx + blockDim.x - 32];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++) {
{ local_B_4bit[col] = B[(col_offset + col) * ldb + idx];
local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; local_B_4bit[col + 16] = B[(col_offset + col) * ldb + idx];
local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx];
} }
loaded_values = 1; loaded_values = 1;
} } else {
else
{
local_A[0] = local_A[1]; local_A[0] = local_A[1];
loaded_values--; loaded_values--;
int absidx = (idx + col_offset)/blocksize; int absidx = (idx + col_offset) / blocksize;
half local_absmax = __ldg(&(absmax[absidx])); half local_absmax = __ldg(&(absmax[absidx]));
#pragma unroll 64 #pragma unroll 64
for(int col = 0; col < 64; col+=2) for (int col = 0; col < 64; col += 2) {
{ // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); // local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
//local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); // local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); // local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
//local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); // local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); local_B[col] = quant_map[(local_B_4bit[col / 2] >> 4)] * T(absidx);
local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); local_B[col + 1] = quant_map[(local_B_4bit[col / 2] & 0x0F)] * T(absidx);
} }
//printnonzero<T>(local_B, 128, ""); // printnonzero<T>(local_B, 128, "");
} }
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
} local_B[col];
else if(warp_id < (WARPS-1)) } else if (warp_id < (WARPS - 1)) {
{
local_A[0] = T(0.0); local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
local_B[col] = 0.0f; local_B[col] = 0.0f;
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
} }
ticktock = ticktock == 0 ? 1 : 0; ticktock = ticktock == 0 ? 1 : 0;
if(warp_id == (WARPS-1)) if (warp_id == (WARPS - 1))
for(int k = 0; k < batch_size_warps; k++) for (int k = 0; k < batch_size_warps; k++) {
{ wmma::load_matrix_sync(
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu ); // 111 mu
wmma::load_matrix_sync(
b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16
); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
} }
} }
__syncthreads(); __syncthreads();
//if(threadIdx.x == 0) // if(threadIdx.x == 0)
//{ //{
// printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); // printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
// printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); // printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
//} // }
if(warp_id != (WARPS-1)){ return; } if (warp_id != (WARPS - 1)) {
return;
}
// only warp_id == (WARPS-1) from here // only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32; int warp_lane = threadIdx.x % 32;
ticktock = ticktock == 0 ? 1 : 0; ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++) for (int k = 0; k < batch_size_warps; k++) {
{ // if(warp_lane == 0)
//if(warp_lane == 0) // printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
//printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
} }
// 129 mu // 129 mu
if(warp_id == (WARPS-1)) if (warp_id == (WARPS - 1))
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
//printnonzero<T>(smem_C, 32, ""); // printnonzero<T>(smem_C, 32, "");
if(col_offset + warp_lane < M) if (col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_C[warp_lane]; out[col_offset + warp_lane] = smem_C[warp_lane];
#endif #endif
} }
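// ---------------------------------------------------------------------------
// [Editor's sketch, not part of this commit] The 4-bit unpacking pattern used
// above: each byte of B stores two 4-bit codes, high nibble first, and each
// code indexes a 16-entry lookup table whose result is scaled by the
// quantization block's statistic. This host helper shows that step in
// isolation; `quant_map` stands in for the NF4/FP4 table loaded from
// `nf4_data` or `datatype`.
static void dequant_4bit_pair_ref(unsigned char packed, const float quant_map[16], float absmax, float out[2]) {
    out[0] = quant_map[packed >> 4] * absmax;   // high nibble -> first value
    out[1] = quant_map[packed & 0x0F] * absmax; // low nibble  -> second value
}
// ---------------------------------------------------------------------------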
#define num_values_4bit 32 #define num_values_4bit 32
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
{ template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
int lda, int ldb, int ldc, int blocksize
) {
// per threadblock: // per threadblock:
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
// 4 warps -> 4 loads per iter // 4 warps -> 4 loads per iter
// 1x32 * 32x4 -> 1x4 outputs per thread block // 1x32 * 32x4 -> 1x4 outputs per thread block
typedef cub::WarpReduce<float> WarpReduce; typedef cub::WarpReduce<float> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32];
const int warp_idx = threadIdx.x / 32; const int warp_idx = threadIdx.x / 32;
const int warp_lane = threadIdx.x % 32; const int warp_lane = threadIdx.x % 32;
const int row_B = (THREADS/32)*blockIdx.x + warp_idx; const int row_B = (THREADS / 32) * blockIdx.x + warp_idx;
const int offset_B = ldb*row_B; const int offset_B = ldb * row_B;
const int num_values_8bit = num_values_4bit/2; const int num_values_8bit = num_values_4bit / 2;
float local_C = 0.0f; float local_C = 0.0f;
unsigned char local_B_4bit[num_values_8bit]; unsigned char local_B_4bit[num_values_8bit];
T local_B[num_values_4bit/4]; T local_B[num_values_4bit / 4];
T local_A[num_values_4bit/4]; T local_A[num_values_4bit / 4];
__shared__ T quant_map[16]; __shared__ T quant_map[16];
T local_absmax = T(0.0f); T local_absmax = T(0.0f);
if (threadIdx.x < 16) if (threadIdx.x < 16)
quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
//for(int i = threadIdx.x; i < 16; i++) // for(int i = threadIdx.x; i < 16; i++)
//quant_map[i] = T(__ldg(&datatype[i])); // quant_map[i] = T(__ldg(&datatype[i]));
__syncthreads(); __syncthreads();
// A: [1, K] // A: [1, K]
// B: [N, K] // B: [N, K]
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
{ const int inner_idx_halved = inner_idx / 2;
const int inner_idx_halved = inner_idx/2;
// Since blocksize will always be a power-of-2, we avoid more expensive // Since blocksize will always be a power-of-2, we avoid more expensive
// division by the blocksize and instead use a shift operation. // division by the blocksize and instead use a shift operation.
        // This is equivalent to ((2*offset_B) + inner_idx) / blocksize.        // This is equivalent to ((2*offset_B) + inner_idx) / blocksize.
const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); const int absidx = ((2 * offset_B) + inner_idx) >> (31 - __clz(blocksize));
local_absmax = __ldg(&(absmax[absidx])); local_absmax = __ldg(&(absmax[absidx]));
if(row_B < M) if (row_B < M) {
{ if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
if((inner_idx_halved + num_values_8bit) < (K/2))
{
// this is the most important for performance considerations // this is the most important for performance considerations
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] =
} reinterpret_cast<int4*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
else } else {
{ #pragma unroll
#pragma unroll for (int j = 0; j < (num_values_8bit); j++)
for(int j = 0; j < (num_values_8bit); j++) if ((inner_idx_halved) + j < (K / 2))
if((inner_idx_halved) + j < (K/2)) local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
else else
local_B_4bit[j] = 0b01110111; local_B_4bit[j] = 0b01110111;
} }
} } else {
else #pragma unroll
{ for (int j = 0; j < (num_values_8bit); j++)
#pragma unroll
for(int j = 0; j < (num_values_8bit); j++)
local_B_4bit[j] = 0b01110111; local_B_4bit[j] = 0b01110111;
} }
for(int i = 0; i < 4; i++) for (int i = 0; i < 4; i++) {
{ #pragma unroll
#pragma unroll for (int k = 0; k < num_values_8bit / 4; k++) {
for(int k = 0; k < num_values_8bit/4; k++) #if BNB_BF16_AVAILABLE
{ local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
#if BNB_BF16_AVAILABLE local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; #else
local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
#else
                // bf16 multiplication not supported                // bf16 multiplication not supported
local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); local_B[k * 2] =
local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * (float)local_absmax);
#endif local_B[k * 2 + 1] =
T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * (float)local_absmax);
#endif
} }
if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
{
// this is also relatively important for performance // this is also relatively important for performance
if(BITS==16) if (BITS == 16) {
{ reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + i]; reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 4) + i];
} } else {
else reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
{ reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] =
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
} }
} } else
else #pragma unroll
#pragma unroll for (int k = 0; k < num_values_4bit / 4; k++)
for(int k = 0; k < num_values_4bit/4; k++) if (inner_idx + (i * num_values_4bit / 4) + k < K)
if(inner_idx + (i*num_values_4bit/4) + k < K) local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)];
else else
local_A[k] = T(0.0f); local_A[k] = T(0.0f);
// accumulate in float; small performance hit for Ampere, but lower error for outputs
// accumulate in float; small performance hit for Ampere, but lower error for outputs #pragma unroll
#pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) {
for(int k = 0; k < num_values_4bit/4; k++) #if BNB_BF16_AVAILABLE
{ local_C += (float)(local_A[k] * local_B[k]);
#if BNB_BF16_AVAILABLE #else
local_C += (float)(local_A[k]*local_B[k]);
#else
                // bf16 multiplication not supported                // bf16 multiplication not supported
local_C += ((float)local_A[k]*(float)local_B[k]); local_C += ((float)local_A[k] * (float)local_B[k]);
#endif #endif
} }
} }
} }
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
if(row_B < M && warp_lane == 0) if (row_B < M && warp_lane == 0)
out[row_B] = T(local_C); out[row_B] = T(local_C);
} }
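// ---------------------------------------------------------------------------
// [Editor's sketch, not part of this commit] The power-of-two trick used above
// to locate the absmax block: for a power-of-two blocksize,
// 31 - __clz(blocksize) == log2(blocksize), so `x >> (31 - __clz(blocksize))`
// equals `x / blocksize` without an integer division. Host analogue using the
// GCC/Clang builtin (assumed available):
static inline int absmax_block_index(int linear_idx, int blocksize) {
    const int log2_blocksize = 31 - __builtin_clz(blocksize); // blocksize must be a power of two
    return linear_idx >> log2_blocksize;                      // == linear_idx / blocksize
}
// ---------------------------------------------------------------------------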
template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n) template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n) {
{ for (long i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += (blockDim.x * gridDim.x)) {
for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) switch (FUNC) {
{
switch(FUNC)
{
case FILL: case FILL:
A[i] = (T)value; A[i] = (T)value;
break; break;
...@@ -2853,70 +2671,143 @@ template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long ...@@ -2853,70 +2671,143 @@ template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long
A[i] = (T)i; A[i] = (T)i;
break; break;
case _MUL: case _MUL:
A[i] = A[i]*B[i]; A[i] = A[i] * B[i];
break; break;
} }
} }
} }
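// ---------------------------------------------------------------------------
// [Editor's sketch, not part of this commit] Because FUNC is a template
// parameter, the switch in kfunc is resolved at compile time, so each
// instantiation below reduces to a single grid-stride elementwise operation.
// A host analogue with illustrative constants (the real FILL/ARANGE/_MUL enum
// is defined elsewhere in the library):
enum { REF_FILL = 0, REF_ARANGE = 1, REF_MUL = 2 };

template <int FUNC> static void func_ref(float* A, const float* B, float value, long n) {
    for (long i = 0; i < n; i++) {
        if (FUNC == REF_FILL)
            A[i] = value;
        else if (FUNC == REF_ARANGE)
            A[i] = (float)i;
        else if (FUNC == REF_MUL)
            A[i] = A[i] * B[i];
    }
}
// ---------------------------------------------------------------------------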
//============================================================== //==============================================================
// TEMPLATE DEFINITIONS // TEMPLATE DEFINITIONS
//============================================================== //==============================================================
template __global__ void kfunc<float, FILL>(float *A, float *B, float value, long n); template __global__ void kfunc<float, FILL>(float* A, float* B, float value, long n);
template __global__ void kfunc<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n); template __global__ void kfunc<unsigned char, FILL>(unsigned char* A, unsigned char* B, unsigned char value, long n);
template __global__ void kfunc<float, ARANGE>(float *A, float *B, float value, long n); template __global__ void kfunc<float, ARANGE>(float* A, float* B, float value, long n);
template __global__ void kfunc<float, _MUL>(float *A, float *B, float value, long n); template __global__ void kfunc<float, _MUL>(float* A, float* B, float value, long n);
// these are not used and make no sense, but the compiler needs them // these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); // template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B,
template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); // float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 32, 256>(
template __global__ void gemm_device<half, 32, 160>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); );
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 32, 192>(
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); );
template __global__ void gemm_device<half, 32, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 32, 160>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 128>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
// template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 64>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 96>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
// these are not used and make no sense, but the compiler needs them // these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); // template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B,
template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); // float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 16, 256>(
template __global__ void gemm_device<half, 16, 160>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); );
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 16, 192>(
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); );
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 16, 160>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); );
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void gemm_device<half, 16, 128>(
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); );
// template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B,
template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); // float * out, int lda, int ldb, int ldc);
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); template __global__ void gemm_device<half, 16, 32>(
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void gemm_device<half, 16, 64>(
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); );
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void gemm_device<half, 16, 96>(
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); );
template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); template __global__ void kgemm_4bit_inference<half, 96>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 128>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 160>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 256>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out,
int lda, int ldb, int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(
int M, int N, int K, __nv_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
__nv_bfloat16* out, int lda, int ldb, int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(
int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
float* out, int lda, int ldb, int ldc, int blocksize
);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template __global__ void kdequant_mm_int32_fp16<4, 512>(
int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);
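The lines above are explicit template instantiations: the kernel templates are defined in this translation unit, so every <type, tile, bits> combination the host launchers need must be instantiated here for the symbol to exist at link time. A minimal sketch of the same pattern with toy names (toy_scale is not part of bitsandbytes):

// Kernel template: each thread processes N_PER_TH consecutive elements,
// mirroring the NUM_PER_TH-style template parameters used above.
template <typename T, int N_PER_TH> __global__ void toy_scale(T* x, T s, int n) {
    int base = (blockIdx.x * blockDim.x + threadIdx.x) * N_PER_TH;
    for (int j = 0; j < N_PER_TH; ++j)
        if (base + j < n)
            x[base + j] *= s;
}

// Without explicit instantiations no concrete kernel is emitted from this
// translation unit, and callers in other files fail to link.
template __global__ void toy_scale<float, 4>(float* x, float s, int n);
template __global__ void toy_scale<double, 4>(double* x, double s, int n);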
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \ template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>( \
float* state1, float *unorm, \ gtype * g, gtype * p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, \
const float beta1, const float beta2, const float eps, const float weight_decay, \ const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n \
const int step, const float lr, const float gnorm_scale, const int n); \ );
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
...@@ -2932,8 +2823,11 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) ...@@ -2932,8 +2823,11 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
#define MAKE_Optimizer32bit1State(oname, gtype) \ #define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ template __global__ void kOptimizer32bit1State<gtype, oname>( \
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ gtype * g, gtype * p, float* state1, float* unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, const int step, \
const float lr, const float gnorm_scale, const bool skip_zeros, const int n \
);
MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float) MAKE_Optimizer32bit1State(MOMENTUM, float)
...@@ -2949,10 +2843,11 @@ MAKE_Optimizer32bit1State(ADAGRAD, float) ...@@ -2949,10 +2843,11 @@ MAKE_Optimizer32bit1State(ADAGRAD, float)
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16) MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \ template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>( \
float* state1, float* state2, float *unorm, \ gtype * g, gtype * p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, \
const float beta1, const float beta2, const float eps, const float weight_decay, \ const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, \
const int step, const float lr, const float gnorm_scale, const int n); \ const int n \
);
MAKE_PreconditionOptimizer32bit2State(ADAM, float) MAKE_PreconditionOptimizer32bit2State(ADAM, float)
MAKE_PreconditionOptimizer32bit2State(ADAM, half) MAKE_PreconditionOptimizer32bit2State(ADAM, half)
...@@ -2961,31 +2856,49 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) ...@@ -2961,31 +2856,49 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, template __global__ void kOptimizer32bit2State<float, ADAM>(
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const int n
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); );
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, template __global__ void kOptimizer32bit2State<half, ADAM>(
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const int n
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); );
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(
__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(
float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(
half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(
__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ #define MAKE_PreconditionStatic8bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>( \
float *unorm, \ gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, \
const float beta1, \ const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, \
const float beta2, \ float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n \
const float eps, const int step, \ );
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
const float weight_decay, \
const float gnorm_scale, \
const int n); \
MAKE_PreconditionStatic8bit1State(MOMENTUM, half) MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float) MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
...@@ -2997,16 +2910,12 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half) ...@@ -2997,16 +2910,12 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
MAKE_PreconditionStatic8bit1State(ADAGRAD, float) MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
#define MAKE_optimizerStatic8bit1State(oname, gtype) \ #define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \ template __global__ void kOptimizerStatic8bit1State<gtype, oname>( \
const float *unorm, const float max_unorm, const float param_norm, \ gtype * p, gtype* const g, unsigned char* state1, const float* unorm, const float max_unorm, \
const float beta1, \ const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, \
const float beta2, \ float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, \
const float eps, const int step, const float lr, \ const float gnorm_scale, const int n \
float* __restrict__ const quantiles1, \ );
float* max1, float* new_max1, \
float weight_decay, \
const float gnorm_scale, \
const int n); \
MAKE_optimizerStatic8bit1State(MOMENTUM, half) MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float) MAKE_optimizerStatic8bit1State(MOMENTUM, float)
...@@ -3017,39 +2926,39 @@ MAKE_optimizerStatic8bit1State(LION, float) ...@@ -3017,39 +2926,39 @@ MAKE_optimizerStatic8bit1State(LION, float)
MAKE_optimizerStatic8bit1State(ADAGRAD, half) MAKE_optimizerStatic8bit1State(ADAGRAD, half)
MAKE_optimizerStatic8bit1State(ADAGRAD, float) MAKE_optimizerStatic8bit1State(ADAGRAD, float)
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ #define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>( \
float *unorm, \ gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, \
const float beta1, const float beta2, \ unsigned char* __restrict__ const state2, float* unorm, const float beta1, const float beta2, const float eps, \
const float eps, const int step, \ const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n \
float* max1, float* max2, float* new_max1, float* new_max2, \ );
const float gnorm_scale, \
const int n); \
MAKE_PreconditionStatic8bit2State(ADAM, half) MAKE_PreconditionStatic8bit2State(ADAM, half)
MAKE_PreconditionStatic8bit2State(ADAM, float) MAKE_PreconditionStatic8bit2State(ADAM, float)
#define MAKE_optimizerStatic8bit2State(oname, gtype) \ #define MAKE_optimizerStatic8bit2State(oname, gtype) \
template __global__ void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ template __global__ void kOptimizerStatic8bit2State<gtype, oname>( \
const float *unorm, const float max_unorm, const float param_norm, \ gtype * p, gtype* const g, unsigned char* state1, unsigned char* state2, const float* unorm, \
const float beta1, const float beta2, \ const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, \
const float eps, const int step, const float lr, \ const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, \
float* max1, float* max2, float* new_max1, float* new_max2, \ const int n \
float weight_decay, \ );
const float gnorm_scale, \
const int n); \
MAKE_optimizerStatic8bit2State(ADAM, half) MAKE_optimizerStatic8bit2State(ADAM, half)
MAKE_optimizerStatic8bit2State(ADAM, float) MAKE_optimizerStatic8bit2State(ADAM, float)
template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n); template __global__ void
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n); kPercentileClipping<float, 2048, 4>(float* __restrict__ g, float* gnorm_vec, int step, const int n);
template __global__ void
kPercentileClipping<half, 2048, 4>(half* __restrict__ g, float* gnorm_vec, int step, const int n);
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>( \
float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \
const int rand_offset, const int n \
);
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit)
...@@ -3119,24 +3028,41 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4) ...@@ -3119,24 +3028,41 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)
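Each MAKE_kQuantizeBlockwise(...) invocation above expands to one explicit instantiation of kQuantizeBlockwise. For example, substituting the first invocation into the macro body shown above gives:

template __global__ void kQuantizeBlockwise<half, 4096, 4, 0, General8bit>(
    float* code, half* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
);

The list of invocations is therefore just a compact way to emit the full matrix of dtype x blocksize x stochastic x data-format kernels.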
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); );
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); );
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); );
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(
float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(
float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n
);
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>( \
const float beta1, const float beta2, const float beta3, const float alpha, \ gtype * p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, \
const float eps, const int step, const float lr, \ const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, \
float* absmax1, float* absmax2, \ float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \
float weight_decay, \ );
const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1)
...@@ -3146,14 +3072,11 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) ...@@ -3146,14 +3072,11 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ gtype * p, gtype* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, \
const float beta1, const float beta2, \ const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, \
const float eps, const int step, const float lr, \ float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \
float* __restrict__ const quantiles1, \ );
float* absmax1, \
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1)
...@@ -3168,5 +3091,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) ...@@ -3168,5 +3091,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)
template __device__ void printnonzero<float>(float *A, int num_values, const char*strval); template __device__ void printnonzero<float>(float* A, int num_values, const char* strval);
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval); template __device__ void printnonzero<half>(half* A, int num_values, const char* strval);
...@@ -9,116 +9,129 @@ ...@@ -9,116 +9,129 @@
#ifndef kernels #ifndef kernels
#define kernels #define kernels
__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n);
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); __global__ void kDequantize(float* code, unsigned char* A, float* out, const int n);
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); __global__ void kQuantizeBlockwise(
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS> );
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
float* state1, float* state2, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha,
const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
const float weight_decay,
const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
float weight_decay, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void __global__ void
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);
const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
const float eps, const int step, const float lr, __global__ void kPreconditionOptimizer32bit2State(
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,
float* max1, float* max2, float* new_max1, float* new_max2, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
float weight_decay, const float gnorm_scale, const int n); );
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise( template <typename T, int OPTIMIZER>
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, __global__ void kOptimizer32bit2State(
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise( );
T* p, T* __restrict__ const g, unsigned char* state1,
const float beta1, const float beta2, template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
const float eps, const int step, const float lr, __global__ void kPreconditionOptimizer32bit1State(
float* __restrict__ const quantiles1, T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps,
float* absmax1, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
float weight_decay, );
const float gnorm_scale, const bool skip_zeros, const int n);
template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(
template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1,
const float beta2, const float eps, const float weight_decay, const int step, const float lr,
template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); const float gnorm_scale, const bool skip_zeros, const int n
);
template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp16(
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, template <typename T, int OPTIMIZER>
half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); __global__ void kPreconditionOptimizerStatic8bit1State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1,
template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols); const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1,
template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); float* new_max1, const float weight_decay, const float gnorm_scale, const int n
);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template <typename T, int OPTIMIZER>
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); __global__ void kOptimizerStatic8bit1State(
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm,
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale,
template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n); const int n
);
template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit2State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2,
float* unorm, const float beta1, const float beta2, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit2State(
T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit2StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2,
const float beta3, const float alpha, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps,
const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n);
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16(
int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);
template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
);
template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
int lda, int ldb, int ldc, int blocksize
);
template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n);
#endif #endif
...@@ -5,37 +5,34 @@ ...@@ -5,37 +5,34 @@
#define NUM 4 #define NUM 4
#define NUM_BLOCK 4096 #define NUM_BLOCK 4096
static inline MPSGraph* get_graph() static inline MPSGraph* get_graph() {
{
static MPSGraph* cur = nil; static MPSGraph* cur = nil;
if(!cur) { if (!cur) {
cur = [[MPSGraph alloc] init]; cur = [[MPSGraph alloc] init];
} }
return cur; return cur;
} }
static inline id<MTLDevice> get_device() static inline id<MTLDevice> get_device() {
{ NSError* error = nil;
NSError *error = nil;
static id<MTLDevice> device = nil; static id<MTLDevice> device = nil;
if(!device) { if (!device) {
device = MTLCreateSystemDefaultDevice(); device = MTLCreateSystemDefaultDevice();
} }
if(!device) { if (!device) {
NSLog(@"Failed to get MPS device"); NSLog(@"Failed to get MPS device");
abort(); abort();
} }
return device; return device;
} }
static inline id<MTLLibrary> get_library() static inline id<MTLLibrary> get_library() {
{ NSError* error = nil;
NSError *error = nil;
static id<MTLLibrary> library = nil; static id<MTLLibrary> library = nil;
if(!library) { if (!library) {
library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error];
} }
if(!library) { if (!library) {
NSLog(@"Failed to load bitsandbytes.metallib"); NSLog(@"Failed to load bitsandbytes.metallib");
abort(); abort();
} }
...@@ -44,20 +41,18 @@ static inline id<MTLLibrary> get_library() ...@@ -44,20 +41,18 @@ static inline id<MTLLibrary> get_library()
/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) /*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n)
{ {
id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"]; id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0
return out; dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out;
}*/ }*/
// MPSGraph function for quantize // MPSGraph function for quantize
extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) {
{
id<MTLDevice> device = get_device(); id<MTLDevice> device = get_device();
id<MTLLibrary> library = get_library(); id<MTLLibrary> library = get_library();
static id<MTLFunction> kernel = nil; static id<MTLFunction> kernel = nil;
if(!kernel) { if (!kernel) {
kernel = [library newFunctionWithName:@"quantize"]; kernel = [library newFunctionWithName:@"quantize"];
if(!kernel) { if (!kernel) {
NSLog(@"Failed to load bitsandbytes.metallib"); NSLog(@"Failed to load bitsandbytes.metallib");
abort(); abort();
} }
......
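The get_graph/get_device/get_library helpers in the Metal section above all follow one pattern: a function-local static that is created on first use and reused on every later call. A plain C++ sketch of the same lazy-initialization idea (illustrative only, not the Metal API):

#include <cstdio>

struct Resource {
    Resource() { std::puts("created once"); }
    void use() { std::puts("used"); }
};

// Function-local static: constructed on the first call, shared afterwards,
// mirroring the `static MPSGraph* cur = nil; if (!cur) ...` helpers above.
static Resource& get_resource() {
    static Resource r;
    return r;
}

int main() {
    get_resource().use();
    get_resource().use(); // no second "created once"
    return 0;
}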
...@@ -3,170 +3,190 @@ ...@@ -3,170 +3,190 @@
// This source code is licensed under the MIT license found in the // This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree. // LICENSE file in the root directory of this source tree.
#include <ops.cuh>
#include <kernels.cuh>
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h> #include <BinSearch.h>
#include <cassert> #include <cassert>
#include <common.h> #include <common.h>
#include <cub/device/device_scan.cuh>
#include <kernels.cuh>
#include <limits>
#include <ops.cuh>
#define ERR_NOT_IMPLEMENTED 100 #define ERR_NOT_IMPLEMENTED 100
using namespace BinSearch; using namespace BinSearch;
using std::cout; using std::cout;
using std::endl; using std::endl;
void quantize(float* code, float* A, unsigned char* out, int n) {
void quantize(float *code, float *A, unsigned char *out, int n) int num_blocks = n / 1024;
{
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kQuantize<<<num_blocks, 1024>>>(code, A, out, n); kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
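The two-step grid computation in quantize() (integer divide, then add one block when n is not a multiple of the thread count) is a plain ceiling division. A small host-side sketch of the equivalence (standalone C++; the helper name is illustrative):

#include <cassert>

// Number of CUDA blocks needed so that num_blocks * threads_per_block >= n,
// written exactly like the two-step form used in quantize()/dequantize() above.
static int ceil_blocks(int n, int threads_per_block) {
    int num_blocks = n / threads_per_block;
    num_blocks = n % threads_per_block == 0 ? num_blocks : num_blocks + 1;
    return num_blocks;
}

int main() {
    // Same result as the (n + tile_size - 1) / tile_size form used by dequantizeBlockwise below.
    for (int n : {1, 1023, 1024, 1025, 4096, 100001})
        assert(ceil_blocks(n, 1024) == (n + 1024 - 1) / 1024);
    return 0;
}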
void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream) void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
{ int num_blocks = n / 1024;
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kDequantize<<<num_blocks, 1024, 0, stream>>>(code, A, out, n); kDequantize<<<num_blocks, 1024, 0, stream>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) template <typename T, int STOCHASTIC, int DATA_TYPE>
{ void quantizeBlockwise(
int num_blocks = n/blocksize; float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
) {
int num_blocks = n / blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
if(blocksize == 4096) if (blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE>
else if(blocksize == 2048) <<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if (blocksize == 2048)
kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024) else if (blocksize == 1024)
kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512) else if (blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 256) else if (blocksize == 256)
kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128) else if (blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 64) else if (blocksize == 64)
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
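The if/else chain in quantizeBlockwise pairs each supported blocksize with a fixed thread count so that threads_per_block times values_per_thread equals the blocksize (4096 -> 1024x4, 2048 -> 512x4, 1024 -> 256x4, 512 -> 256x2, 256 -> 128x2, 128 -> 64x2, 64 -> 32x2). A host-side sketch of that mapping (illustrative only, not an API of this file):

#include <cstdio>

struct LaunchShape { int threads; int per_thread; };

// Mirrors the dispatch above: every supported blocksize has one branch.
static LaunchShape blockwise_launch_shape(int blocksize) {
    switch (blocksize) {
        case 4096: return {1024, 4};
        case 2048: return {512, 4};
        case 1024: return {256, 4};
        case 512:  return {256, 2};
        case 256:  return {128, 2};
        case 128:  return {64, 2};
        case 64:   return {32, 2};
        default:   return {0, 0}; // unsupported sizes fall through with no launch, as above
    }
}

int main() {
    for (int bs : {4096, 2048, 1024, 512, 256, 128, 64}) {
        LaunchShape s = blockwise_launch_shape(bs);
        printf("blocksize %4d -> %4d threads x %d values/thread\n", bs, s.threads, s.per_thread);
    }
    return 0;
}

Note that only the 4096 branch forwards the STOCHASTIC template flag; the smaller block sizes instantiate the kernel with 0.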
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream) template <typename T, int DATA_TYPE>
{ void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, cudaStream_t stream
) {
// printf("stream==%d\n",stream); // printf("stream==%d\n",stream);
int num_blocks = n/blocksize; int num_blocks = n / blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
int tile_size = (DATA_TYPE > 0) ? 1024 : 512; int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
if(DATA_TYPE > 0) if (DATA_TYPE > 0)
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize/2, n); kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
<<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
else else
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n); kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
<<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
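For the 4-bit data types (DATA_TYPE > 0, i.e. FP4/NF4) dequantizeBlockwise uses a tile_size of 1024 instead of 512 and passes blocksize/2 to the kernel, consistent with two 4-bit codes being packed into each byte of A, so a block of blocksize values occupies blocksize/2 bytes. A small sketch of that size relation (the helper name is illustrative, and the packing assumption is inferred from the blocksize/2 above rather than stated in this file):

#include <cassert>

// Bytes of quantized storage for n values: two 4-bit codes per byte for
// FP4/NF4, one byte per value for General8bit.
static long long quantized_bytes(long long n, bool four_bit) {
    return four_bit ? (n + 1) / 2 : n;
}

int main() {
    assert(quantized_bytes(4096, true) == 2048);   // matches the blocksize/2 passed above
    assert(quantized_bytes(4096, false) == 4096);  // the 8-bit path keeps blocksize as-is
    return 0;
}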
template <typename T, int OPTIMIZER>
void optimizer32bit(
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, const float beta1,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const float lr, const float gnorm_scale, bool skip_zeros, const int n
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) ) {
{ int num_blocks = n / 4096;
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
switch(OPTIMIZER) switch (OPTIMIZER) {
{
case ADAM: case ADAM:
case ADEMAMIX: case ADEMAMIX:
if(max_unorm > 0.0f) if (max_unorm > 0.0f) {
{ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr,
gnorm_scale, skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case MOMENTUM: case MOMENTUM:
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
if(max_unorm > 0.0f) if (max_unorm > 0.0f) {
{ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); <<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale,
skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case LION: case LION:
// in lion, the momentum update happens after the parameter update // in lion, the momentum update happens after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale,
skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
if(max_unorm > 0.0f) if (max_unorm > 0.0f) {
{ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); <<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
break; break;
} }
} }
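A condensed view of the control flow in optimizer32bit: for ADAM/ADEMAMIX and MOMENTUM/RMSPROP/ADAGRAD the precondition kernel that fills unorm runs before the main update whenever max_unorm > 0, while the LION branch launches the main update first and computes the norm afterwards, because Lion applies its momentum update after the parameter update. A host-side sketch with hypothetical stand-ins for the kernel launches (precondition_unorm, main_update, and run_step are illustrative names, not functions in this file):

#include <cstdio>

static void precondition_unorm() { std::puts("precondition: compute update norm"); }
static void main_update()        { std::puts("main optimizer step"); }

// ADAM/MOMENTUM-style paths compute the norm before the step; LION computes it after.
static void run_step(bool is_lion, float max_unorm) {
    if (!is_lion) {
        if (max_unorm > 0.0f) precondition_unorm();
        main_update();
    } else {
        main_update();
        if (max_unorm > 0.0f) precondition_unorm();
    }
}

int main() {
    run_step(false, 1.0f);
    run_step(true, 1.0f);
    return 0;
}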
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, template <typename T, int OPTIMIZER>
unsigned char* state1, unsigned char* state2, void optimizerStatic8bit(
float *unorm, float max_unorm, float param_norm, T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
float beta1, float beta2, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
float eps, int step, float lr, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
float* quantiles1, float* quantiles2, ) {
float* max1, float* max2, float* new_max1, float* new_max2, int num_blocks = n / 4096;
float weight_decay,
const float gnorm_scale, int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } if (max_unorm > 0.0f) {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
}
switch(OPTIMIZER) switch (OPTIMIZER) {
{
case ADAM: case ADAM:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(
p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1,
new_max2, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2,
max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case MOMENTUM: case MOMENTUM:
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(
p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1,
weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case LION: case LION:
// in lion, the momentum update happens after the parameter update // in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1,
weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(
p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
default: default:
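// A minimal scalar sketch of the Lion rule (illustrative names p, m, g; not part of
// this file), showing why the state kernels can run after the update kernel here:
//   c = beta1 * m + (1 - beta1) * g;          // direction uses the *old* momentum
//   p -= lr * (sign(c) + weight_decay * p);   // parameter step does not need the new state
//   m = beta2 * m + (1 - beta2) * g;          // momentum is written only afterwards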
...@@ -179,39 +199,23 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, ...@@ -179,39 +199,23 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
#define BLOCKSIZE_1STATE 256 #define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1 #define NUM_1STATE 1
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise( template <typename T, int OPTIMIZER>
T* p, void optimizerStatic8bitBlockwise(
T* g, T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
unsigned char* state1, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
unsigned char* state2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n
float beta1,
float beta2,
float beta3,
float alpha,
float eps,
int step,
float lr,
float* quantiles1,
float* quantiles2,
float* absmax1,
float* absmax2,
float weight_decay,
const float gnorm_scale,
bool skip_zeros,
int n
) { ) {
int num_blocks = 0; int num_blocks = 0;
switch(OPTIMIZER) switch (OPTIMIZER) {
{
case ADAM: case ADAM:
case ADEMAMIX: case ADEMAMIX:
num_blocks = n/BLOCKSIZE_2STATE; num_blocks = n / BLOCKSIZE_2STATE;
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>( kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE>
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, <<<num_blocks, BLOCKSIZE_2STATE / NUM_2STATE>>>(
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1,
skip_zeros, n absmax2, weight_decay, gnorm_scale, skip_zeros, n
); );
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
...@@ -219,88 +223,76 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise( ...@@ -219,88 +223,76 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
case LION: case LION:
num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n / BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr, kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE>
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); <<<num_blocks, BLOCKSIZE_1STATE / NUM_1STATE>>>(
p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
} }
} }
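// The two-line block count above is ordinary ceiling division; an equivalent one-liner
// (a sketch, not used in this file):
static inline int ceil_div(int n, int block) { return (n + block - 1) / block; }
// e.g. ceil_div(4097, 4096) == 2, matching n / BLOCKSIZE_2STATE plus the remainder check.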
template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n) {
int num_blocks = n / 2048;
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
int num_blocks = n/2048;
num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float)));
kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n); kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
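// Note on the indexing above: gnorm_vec[step % 100] is zeroed on the host and then refilled
// by kPercentileClipping, which suggests a rolling 100-slot history of gradient norms
// (the kernel body is outside this diff, so this is a reading of the launch code only).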
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) void gemmex(
{ Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc
) {
const int falpha = 1; const int falpha = 1;
const int fbeta = 0; const int fbeta = 0;
const void * alpha = &falpha; const void* alpha = &falpha;
const void * beta = &fbeta; const void* beta = &fbeta;
cublasStatus_t status; cublasStatus_t status;
status = cublasGemmEx(context->m_handle, status = cublasGemmEx(
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k,
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, C, CUDA_R_32I, ldc, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP
m, n, k, );
alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
C, CUDA_R_32I, ldc,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
if (status != CUBLAS_STATUS_SUCCESS) if (status != CUBLAS_STATUS_SUCCESS) {
{
std::cout << "CUBLAS ERROR: Status " << status << std::endl; std::cout << "CUBLAS ERROR: Status " << status << std::endl;
} }
} }
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, void strided_gemmex(
long long int strideA, long long int strideB, long long int strideC, int batchCount) Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
{ int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
) {
const int falpha = 1; const int falpha = 1;
const int fbeta = 0; const int fbeta = 0;
const void * alpha = &falpha; const void* alpha = &falpha;
const void * beta = &fbeta; const void* beta = &fbeta;
cublasStatus_t status; cublasStatus_t status;
//cout << transposeA << transposeB << endl; // cout << transposeA << transposeB << endl;
//printf("%i %i %i\n", m,n,k); // printf("%i %i %i\n", m,n,k);
//printf("%i %i %i\n", lda,ldb,ldc); // printf("%i %i %i\n", lda,ldb,ldc);
//printf("%i %i %i\n", strideA, strideB, strideC); // printf("%i %i %i\n", strideA, strideB, strideC);
//printf("%i\n", batchCount); // printf("%i\n", batchCount);
status = cublasGemmStridedBatchedEx(context->m_handle,
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
m, n, k,
alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
if (status != CUBLAS_STATUS_SUCCESS)
{
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
} status = cublasGemmStridedBatchedEx(
context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k,
alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, C,
CUDA_R_32I, ldc, (long long int)strideC, batchCount, CUDA_R_32I, CUBLAS_GEMM_DEFAULT
);
int roundoff(int v, int d) { if (status != CUBLAS_STATUS_SUCCESS) {
return (v + d - 1) / d * d; std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
} }
int roundoff(int v, int d) { return (v + d - 1) / d * d; }
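// Worked example for roundoff: it rounds v up to the next multiple of d.
static_assert((17 + 8 - 1) / 8 * 8 == 24, "roundoff(17, 8) == 24");
static_assert((16 + 8 - 1) / 8 * 8 == 16, "exact multiples are left unchanged");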
template<int ORDER> cublasLtOrder_t get_order() template <int ORDER> cublasLtOrder_t get_order() {
{ switch (ORDER) {
switch(ORDER)
{
case ROW: case ROW:
return CUBLASLT_ORDER_ROW; return CUBLASLT_ORDER_ROW;
break; break;
...@@ -329,11 +321,8 @@ template cublasLtOrder_t get_order<COL32>(); ...@@ -329,11 +321,8 @@ template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>(); template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>(); template cublasLtOrder_t get_order<COL_AMPERE>();
template <int ORDER> int get_leading_dim(int dim1, int dim2) {
template<int ORDER> int get_leading_dim(int dim1, int dim2) switch (ORDER) {
{
switch(ORDER)
{
case ROW: case ROW:
return dim2; return dim2;
break; break;
...@@ -342,14 +331,14 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2) ...@@ -342,14 +331,14 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
break; break;
case COL32: case COL32:
// 32*row tiles // 32*row tiles
return dim1*32; return dim1 * 32;
break; break;
case COL_TURING: case COL_TURING:
return 32*roundoff(dim1, 8); return 32 * roundoff(dim1, 8);
break; break;
case COL_AMPERE: case COL_AMPERE:
// 32*32 tiles // 32*32 tiles
return 32*roundoff(dim1, 32); return 32 * roundoff(dim1, 32);
break; break;
default: default:
return 0; return 0;
...@@ -357,15 +346,10 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2) ...@@ -357,15 +346,10 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
} }
} }
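// Worked example for the tiled layouts above, taking dim1 = 100 rows, dim2 = 64 cols:
//   ROW        -> 64
//   COL32      -> 100 * 32               = 3200
//   COL_TURING -> 32 * roundoff(100, 8)  = 32 * 104 = 3328
//   COL_AMPERE -> 32 * roundoff(100, 32) = 32 * 128 = 4096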
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt( template <int DTYPE_OUT, int SCALE_ROWS>
cublasLtHandle_t ltHandle, int igemmlt(
int m, int n, int k, cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
const int8_t * A, int lda, int ldb, int ldc, cudaStream_t stream
const int8_t * B,
void * C,
float * row_scale,
int lda, int ldb, int ldc,
cudaStream_t stream
) { ) {
// Calculate C = A^T @ B, in col-major layout. // Calculate C = A^T @ B, in col-major layout.
...@@ -393,17 +377,14 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt( ...@@ -393,17 +377,14 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
// Default layout order is col major // Default layout order is col major
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType)); has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); has_error |=
checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));
if (DTYPE_OUT == 32) { if (DTYPE_OUT == 32) {
int alpha = 1, beta = 0; int alpha = 1, beta = 0;
has_error |= checkCublasStatus(cublasLtMatmul( has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc, ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, NULL, NULL,
&alpha, A, aDesc, 0, stream
B, bDesc, &beta,
(int32_t*)C, cDesc,
(int32_t*)C, cDesc,
NULL, NULL, 0, stream
)); ));
} else { } else {
// This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows. // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows.
...@@ -411,29 +392,18 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt( ...@@ -411,29 +392,18 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
if (!SCALE_ROWS) { if (!SCALE_ROWS) {
float alpha = 1.0f, beta = 0.0f; float alpha = 1.0f, beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmul( has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc, ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL,
&alpha, A, aDesc, NULL, 0, stream
B, bDesc, &beta,
(int8_t*)C, cDesc,
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
)); ));
} else { } else {
cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
float beta = 0.0f; float beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute( has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(
matmulDesc, matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, sizeof(alphaVec)
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointerMode,
sizeof(alphaVec)
)); ));
has_error |= checkCublasStatus(cublasLtMatmul( has_error |= checkCublasStatus(cublasLtMatmul(
ltHandle, matmulDesc, ltHandle, matmulDesc, row_scale, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL,
row_scale, A, aDesc, NULL, 0, stream
B, bDesc, &beta,
(int8_t*)C, cDesc,
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
)); ));
} }
} }
...@@ -443,30 +413,33 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt( ...@@ -443,30 +413,33 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc)); has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc));
has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
if(has_error == 1) if (has_error == 1)
printf("error detected"); printf("error detected");
return has_error; return has_error;
} }
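// The template parameters select the output path: DTYPE_OUT == 32 writes int32 results with
// scalar alpha/beta; otherwise the output is int8, and SCALE_ROWS switches alpha to a per-row
// device vector (row_scale). The instantiations used further down are igemmlt<32, 0>,
// igemmlt<8, 0> and igemmlt<8, 1>.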
int fill_up_to_nearest_multiple(int value, int multiple) int fill_up_to_nearest_multiple(int value, int multiple) {
{
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
} }
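// fill_up_to_nearest_multiple is the remainder-based form of the same round-up idea:
//   fill_up_to_nearest_multiple(10, 4) == 10 + (4 - 10 % 4) == 12
//   fill_up_to_nearest_multiple(12, 4) == 12   (already a multiple, nothing is added)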
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, cudaStream_t stream) void dequant_mm_int32_fp16(
{ int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
) {
const int threads = 512; const int threads = 512;
const int num_per_thread = 4; const int num_per_thread = 4;
const int num_per_block = threads * num_per_thread; const int num_per_block = threads * num_per_thread;
const int n = numRows*numCols; const int n = numRows * numCols;
const int num_blocks = (n + num_per_block - 1) / num_per_block; const int num_blocks = (n + num_per_block - 1) / num_per_block;
kdequant_mm_int32_fp16<num_per_thread, threads><<<num_blocks, threads, 0, stream>>>(A, rowStats, colStats, out, bias, numRows, numCols, n); kdequant_mm_int32_fp16<num_per_thread, threads>
<<<num_blocks, threads, 0, stream>>>(A, rowStats, colStats, out, bias, numRows, numCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
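// Grid sizing above, worked through: threads = 512 and num_per_thread = 4 give
// num_per_block = 2048, so a 4096 x 1024 output (n = 4,194,304) launches
// (4194304 + 2047) / 2048 = 2048 blocks.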
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
) {
if (threshold == 0.0) { if (threshold == 0.0) {
kInt8VectorQuant<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols); kInt8VectorQuant<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
} else { } else {
...@@ -475,7 +448,7 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float ...@@ -475,7 +448,7 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
if (threshold == 0.0) if (threshold == 0.0)
kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols); kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
else else
...@@ -483,94 +456,101 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, ...@@ -483,94 +456,101 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols,
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) void spmm_coo(
{ cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
) {
cusparseSpMatDescr_t descA; cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC; cusparseDnMatDescr_t descB, descC;
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
void *dBuffer = NULL; void* dBuffer = NULL;
size_t bufferSize = 0; size_t bufferSize = 0;
CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, CHECK_CUSPARSE(cusparseCreateCoo(
A_rowidx, A_colidx, A_vals, &descA, A_rows, A_cols, A_nnz, A_rowidx, A_colidx, A_vals, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_INDEX_32I, CUDA_R_16F
CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) ); ));
// Create dense matrix C // Create dense matrix C
CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, CHECK_CUSPARSE(cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, CUDA_R_16F, CUSPARSE_ORDER_ROW));
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
// Create dense matrix B // Create dense matrix B
if(transposed_B) if (transposed_B) {
{
int tmp = A_cols; int tmp = A_cols;
A_cols = B_cols; A_cols = B_cols;
B_cols = tmp; B_cols = tmp;
} }
CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, CHECK_CUSPARSE(cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, CUDA_R_16F, CUSPARSE_ORDER_ROW));
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
// allocate an external buffer if needed // allocate an external buffer if needed
CHECK_CUSPARSE( cusparseSpMM_bufferSize( CHECK_CUSPARSE(cusparseSpMM_bufferSize(
handle, handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize
&alpha, descA, descB, &beta, descC, CUDA_R_32F, ));
CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); CUDA_CHECK_RETURN(cudaMalloc(&dBuffer, bufferSize));
CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
// execute SpMM // execute SpMM
CHECK_CUSPARSE( cusparseSpMM(handle, CHECK_CUSPARSE(cusparseSpMM(
CUSPARSE_OPERATION_NON_TRANSPOSE, handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta,
&alpha, descA, descB, &beta, descC, CUDA_R_32F, descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, dBuffer
CUSPARSE_SPMM_ALG_DEFAULT, dBuffer)); ));
// destroy matrix/vector descriptors // destroy matrix/vector descriptors
CHECK_CUSPARSE( cusparseDestroySpMat(descA) ); CHECK_CUSPARSE(cusparseDestroySpMat(descA));
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); CHECK_CUSPARSE(cusparseDestroyDnMat(descB));
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); CHECK_CUSPARSE(cusparseDestroyDnMat(descC));
CUDA_CHECK_RETURN( cudaFree(dBuffer) ); CUDA_CHECK_RETURN(cudaFree(dBuffer));
} }
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) template <typename T, int BITS>
{ void spmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB
);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits) {
{
int num_blocks = (m+31)/32; int num_blocks = (m + 31) / 32;
if(bits == 32) if (bits == 32)
gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); gemm_device<T, 32, 32><<<num_blocks, 32, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16) if (bits == 16)
gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); gemm_device<T, 16, 160><<<num_blocks, 160, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
} }
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) template <typename T>
{ void gemm_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
) {
int num_blocks = (m+31)/32; int num_blocks = (m + 31) / 32;
kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); kgemm_4bit_inference<T, 96><<<num_blocks, 96, 0, 0>>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
} }
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) template <typename T, int BITS>
{ void gemm_4bit_inference_naive(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
int blocksize, cudaStream_t stream
) {
int num_blocks = (m+3)/4; int num_blocks = (m + 3) / 4;
kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); kgemm_4bit_inference_naive<T, 128, BITS>
<<<num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <typename T, int FUNC> void func(T *A, T *B, T value, long n) template <typename T, int FUNC> void func(T* A, T* B, T value, long n) {
{
int threads = 512; int threads = 512;
int blocks = n/threads; int blocks = n / threads;
blocks = n % threads == 0 ? blocks : blocks + 1; blocks = n % threads == 0 ? blocks : blocks + 1;
blocks = blocks > 65535 ? 65535 : blocks; blocks = blocks > 65535 ? 65535 : blocks;
kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n); kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n);
...@@ -581,103 +561,154 @@ template <typename T, int FUNC> void func(T *A, T *B, T value, long n) ...@@ -581,103 +561,154 @@ template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
// TEMPLATE DEFINITIONS // TEMPLATE DEFINITIONS
//============================================================== //==============================================================
template void func<float, FILL>(float *A, float *B, float value, long n); template void func<float, FILL>(float* A, float* B, float value, long n);
template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n); template void func<unsigned char, FILL>(unsigned char* A, unsigned char* B, unsigned char value, long n);
template void func<float, ARANGE>(float *A, float *B, float value, long n); template void func<float, ARANGE>(float* A, float* B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n); template void func<float, _MUL>(float* A, float* B, float value, long n);
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template void gemm_4bit_inference<half>(
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); );
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); template void gemm_4bit_inference_naive<half, 16>(
int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); int ldc, int blocksize, cudaStream_t stream
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); );
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
);
template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); template void gemm_4bit_inference_naive<float, 32>(
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); int ldc, int blocksize, cudaStream_t stream
);
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); // template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc,
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); // int bits);
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void gemm_host<half>(int m, int n, int k, half* A, half* B, half* out, int lda, int ldb, int ldc, int bits);
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void spmm_coo_very_sparse_naive<half, 16>(
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); );
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void spmm_coo_very_sparse_naive<signed char, 8>(
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); template int igemmlt<32, 0>(
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); int lda, int ldb, int ldc, cudaStream_t stream
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); );
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); template int igemmlt<8, 0>(
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); int lda, int ldb, int ldc, cudaStream_t stream
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); );
template int igemmlt<8, 1>(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
);
template void quantizeBlockwise<half, 1, General8bit>(
float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, General8bit>(
float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, FP4>(
float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, NF4>(
float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 1, General8bit>(
float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, General8bit>(
float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, FP4>(
float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, NF4>(
float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
const int n
);
template void dequantizeBlockwise<float, General8bit>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<float, FP4>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<float, NF4>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, General8bit>(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, FP4>(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<half, NF4>(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, FP4>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);
template void dequantizeBlockwise<__nv_bfloat16, NF4>(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
);
#define MAKE_optimizer32bit(name, gtype) \ #define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \ template void optimizer32bit<gtype, name>( \
float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ gtype * g, gtype * p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, \ const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
const float eps, const float weight_decay, \ const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, \
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); const int n \
);
MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, __nv_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(
MAKE_optimizer32bit(ADAM, __nv_bfloat16) MOMENTUM, __nv_bfloat16
MAKE_optimizer32bit(MOMENTUM, half) ) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, __nv_bfloat16) MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(LION, __nv_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16) MAKE_optimizer32bit(ADEMAMIX, half) MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16) MAKE_optimizer32bit(ADEMAMIX, float)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, float)
#define MAKE_optimizerStatic8bit(name, gtype) \ #define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ template void optimizerStatic8bit<gtype, name>( \
float *unorm, float max_unorm, float param_norm, \ gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
float beta1, float beta2, \ float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
float eps, int step, float lr, \ float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
float* quantiles1, float* quantiles2, \ const float gnorm_scale, int n \
float* max1, float* max2, float* new_max1, float* new_max2, \ );
float weight_decay, \
const float gnorm_scale, int n); \
MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float)
MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)
MAKE_optimizerStatic8bit(ADAGRAD, half)
MAKE_optimizerStatic8bit(ADAGRAD, float)
MAKE_optimizerStatic8bit(ADAM, half) MAKE_optimizerStatic8bit(ADAM, float) MAKE_optimizerStatic8bit(MOMENTUM, half) MAKE_optimizerStatic8bit(MOMENTUM, float) MAKE_optimizerStatic8bit(
RMSPROP, half
) MAKE_optimizerStatic8bit(RMSPROP, float) MAKE_optimizerStatic8bit(LION, half) MAKE_optimizerStatic8bit(LION, float) MAKE_optimizerStatic8bit(ADAGRAD, half) MAKE_optimizerStatic8bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \ template void optimizerStatic8bitBlockwise<gtype, optim_name>( \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
);
MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(half, ADAM);
MAKE_optimizerStatic8bitBlockwise(float, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
...@@ -696,8 +727,8 @@ MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); ...@@ -696,8 +727,8 @@ MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX); MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(float* g, float* gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half* g, float* gnorm_vec, int step, const int n);
template int get_leading_dim<ROW>(int dim1, int dim2); template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2); template int get_leading_dim<COL>(int dim1, int dim2);
...@@ -3,41 +3,41 @@ ...@@ -3,41 +3,41 @@
// This source code is licensed under the MIT license found in the // This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree. // LICENSE file in the root directory of this source tree.
#ifndef ops_H #ifndef ops_H
#define ops_H #define ops_H
#include <assert.h>
#include <cstdint> #include <cstdint>
#include <stdio.h>
#include <iostream> #include <iostream>
#include <assert.h> #include <stdio.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h> #include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <cusparse.h> #include <cusparse.h>
#include <vector>
#include <functional> #include <functional>
#include <vector>
#define CUDA_CHECK_RETURN(value) \
#define CUDA_CHECK_RETURN(value) { \ { \
cudaError_t _m_cudaStat = value; \ cudaError_t _m_cudaStat = value; \
if (_m_cudaStat != cudaSuccess) { \ if (_m_cudaStat != cudaSuccess) { \
fprintf(stderr, "Error %s at line %d in file %s\n", \ fprintf(stderr, "Error %s at line %d in file %s\n", cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
exit(1); \ exit(1); \
} } } \
}
#define CHECK_CUSPARSE(value) { \ #define CHECK_CUSPARSE(value) \
{ \
cusparseStatus_t _m_cudaStat = value; \ cusparseStatus_t _m_cudaStat = value; \
if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
fprintf(stderr, "Error %s at line %d in file %s\n", \ fprintf( \
cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__ \
); \
exit(1); \ exit(1); \
} } } \
}
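// Typical use of the two macros above (a sketch; the pointers are placeholders):
//   CUDA_CHECK_RETURN(cudaMalloc(&dBuffer, bufferSize));
//   CHECK_CUSPARSE(cusparseCreate(&handle));
// On failure both print the error string with file and line, then exit(1).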
inline void checkCudaStatus(cudaError_t status) { inline void checkCudaStatus(cudaError_t status) {
if (status != cudaSuccess) { if (status != cudaSuccess) {
...@@ -49,19 +49,17 @@ inline void checkCudaStatus(cudaError_t status) { ...@@ -49,19 +49,17 @@ inline void checkCudaStatus(cudaError_t status) {
inline int checkCublasStatus(cublasStatus_t status) { inline int checkCublasStatus(cublasStatus_t status) {
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
printf("cuBLAS API failed with status %d\n", status); printf("cuBLAS API failed with status %d\n", status);
//throw std::logic_error("cuBLAS API failed"); // throw std::logic_error("cuBLAS API failed");
return 1; return 1;
} }
return 0; return 0;
} }
typedef enum Operations_t typedef enum Operations_t {
{
ksmul = 0, ksmul = 0,
} Operations_t; } Operations_t;
typedef enum Optimizer_t typedef enum Optimizer_t {
{
ADAM = 0, ADAM = 0,
MOMENTUM = 1, MOMENTUM = 1,
RMSPROP = 2, RMSPROP = 2,
...@@ -71,8 +69,7 @@ typedef enum Optimizer_t ...@@ -71,8 +69,7 @@ typedef enum Optimizer_t
ADEMAMIX = 6 ADEMAMIX = 6
} Optimizer_t; } Optimizer_t;
typedef enum Transform_t typedef enum Transform_t {
{
ROW = 0, ROW = 0,
COL = 1, COL = 1,
COL32 = 2, COL32 = 2,
...@@ -80,109 +77,135 @@ typedef enum Transform_t ...@@ -80,109 +77,135 @@ typedef enum Transform_t
COL_AMPERE = 4, COL_AMPERE = 4,
} Transform_t; } Transform_t;
typedef enum DataType_t typedef enum DataType_t {
{
General8bit = 0, General8bit = 0,
FP4 = 1, FP4 = 1,
NF4 = 2, NF4 = 2,
} DataType_t; } DataType_t;
typedef enum Funcs_t typedef enum Funcs_t {
{
FILL = 0, FILL = 0,
ARANGE = 1, ARANGE = 1,
_MUL = 2, _MUL = 2,
} Funcs_t; } Funcs_t;
class Context class Context {
{
public: public:
cublasHandle_t m_handle; cublasHandle_t m_handle;
Context() Context() {
{
cublasHandle_t handle; cublasHandle_t handle;
cublasCreate_v2(&handle); cublasCreate_v2(&handle);
m_handle = handle; m_handle = handle;
} }
}; };
class ContextLt class ContextLt {
{
public: public:
cublasLtHandle_t m_handle; cublasLtHandle_t m_handle;
ContextLt() ContextLt() {
{
cublasLtHandle_t handle; cublasLtHandle_t handle;
cublasLtCreate(&handle); cublasLtCreate(&handle);
m_handle = handle; m_handle = handle;
} }
}; };
class ContextCusparse class ContextCusparse {
{
public: public:
cusparseHandle_t m_handle; cusparseHandle_t m_handle;
ContextCusparse() ContextCusparse() {
{
cusparseHandle_t handle; cusparseHandle_t handle;
cusparseCreate(&handle); cusparseCreate(&handle);
m_handle = handle; m_handle = handle;
} }
}; };
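// Minimal host-side sketch of how the context wrappers are used (pointer arguments are
// assumed to be valid device buffers of the right types; names are illustrative):
//   Context ctx;                 // constructor creates the cuBLAS handle
//   gemmex(&ctx, false, false, m, n, k, dA, dB, dC, lda, ldb, ldc);
//   ContextCusparse sp;          // constructor creates the cuSPARSE handle
//   spmm_coo(sp.m_handle, rowidx, colidx, vals, nnz, rowsA, colsA, colsB, ldb, dB, ldc, dC, false);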
void quantize(float *code, float *A, unsigned char *out, int n); void quantize(float* code, float* A, unsigned char* out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream); void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream);
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template <typename T, int STOCHASTIC, int DATA_TYPE>
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream); void quantizeBlockwise(
float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, );
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, template <typename T, int DATA_TYPE>
float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, void dequantizeBlockwise(
int step, float lr, const float gnorm_scale, bool skip_zeros, int n); float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, cudaStream_t stream
);
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm, template <typename T, int OPTIMIZER>
float beta1, float beta2, void optimizer32bit(
float eps, int step, float lr, T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2,
float* quantiles1, float* quantiles2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale,
float* max1, float* max2, float* new_max1, float* new_max2, bool skip_zeros, int n
float weight_decay, );
const float gnorm_scale, int n);
template <typename T, int OPTIMIZER>
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, void optimizerStatic8bit(
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
bool skip_zeros, int n); float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
);
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
template <typename T, int OPTIMIZER>
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); void optimizerStatic8bitBlockwise(
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
long long int strideA, long long int strideB, long long int strideC, int batchCount); float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, bool skip_zeros, int n
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); );
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); void gemmex(
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); );
void strided_gemmex(
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); );
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); template <int DTYPE_OUT, int SCALE_ROWS>
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); int igemmlt(
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
template <typename T, int FUNC> void func(T *A, T *B, T value, long n); );
void cutlass_igemm(
bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc
);
void dequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
);
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
);
void spmm_coo(
cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
);
template <typename T, int BITS>
void spmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);
void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);
template <typename T>
void gemm_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
);
template <typename T, int BITS>
void gemm_4bit_inference_naive(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
int blocksize, cudaStream_t stream
);
template <typename T, int FUNC> void func(T* A, T* B, T value, long n);
#endif #endif
...@@ -20,39 +20,60 @@ ...@@ -20,39 +20,60 @@
#if BUILD_CUDA #if BUILD_CUDA
//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) // void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); } //{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) void gemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); } gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16);
}
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) void gemm_4bit_inference(
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
) {
gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) void gemm_4bit_inference_naive_fp16(
{ gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) void gemm_4bit_inference_naive_bf16(
{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive<__nv_bfloat16, 16>(
m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream
);
}
void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) void gemm_4bit_inference_naive_fp32(
{ gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { func<ctype, FUNC>(A, B, value, n); }
MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_grad_##gbits( \
gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
const float weight_decay, const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n \
) { \
optimizer32bit<gtype, oname>( \
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \
lr, gnorm_scale, skip_zeros, n \
); \
}
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -70,19 +91,18 @@ MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32)
MAKE_FUNC32(ademamix, ADEMAMIX, half, fp16)
MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
void fname##_static_8bit_grad_##gbits( \
gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
float gnorm_scale, int n \
) { \
optimizerStatic8bit<gtype, oname>( \
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \
max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \
); \
}
MAKE_FUNC8(adam, ADAM, float, 32)
MAKE_FUNC8(adam, ADAM, half, 16)
@@ -94,10 +114,16 @@ MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_grad_##gbits( \
gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
) { \
optimizerStatic8bitBlockwise<gtype, optim_name>( \
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \
weight_decay, gnorm_scale, skip_zeros, n \
); \
}
MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
@@ -118,239 +144,511 @@ MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
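// For reference, MAKE_BLOCKWISE8(adam, ADAM, half, fp16) expands to the following wrapper
// (shown only as an illustrative comment, not emitted separately by this commit):
//
//   void adam_8bit_blockwise_grad_fp16(
//       half* p, half* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3,
//       float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1,
//       float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n
//   ) {
//       optimizerStatic8bitBlockwise<half, ADAM>(
//           p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1,
//           absmax2, weight_decay, gnorm_scale, skip_zeros, n
//       );
//   }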
void percentileClipping_g32(float* g, float* gnorm_vec, int step, const int n) {
percentileClipping<float>(g, gnorm_vec, step, n);
}
void percentileClipping_g16(half* g, float* gnorm_vec, int step, const int n) {
percentileClipping<half>(g, gnorm_vec, step, n);
}
void quantizeBlockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_bf16(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_bf16_fp4(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_bf16_nf4(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp32_fp4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void quantizeBlockwise_fp32_nf4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
}
void dequantizeBlockwise_fp16(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp16_fp4(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp16_nf4(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp32(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp32_fp4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp32_nf4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16_fp4(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16_nf4(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
int igemmlt_32(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
) {
return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int igemmlt_8(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
) {
return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int igemmlt_8_rowscale(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, cudaStream_t stream
) {
return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
void spmm_coo_very_sparse_naive_fp16(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive<half, 16>(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
void spmm_coo_very_sparse_naive_int8(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive<signed char, 8>(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
#endif
extern "C" {
#if BUILD_CUDA
void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
dequantize(code, A, out, n, stream);
}
void cdequantize_blockwise_fp16_fp4(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp16(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp16_nf4(
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cquantize_blockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32_fp4(
float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_fp32_nf4(
float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_fp32(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32_fp4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32_nf4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cquantize_blockwise_bf16(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_bf16_fp4(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n);
}
void cquantize_blockwise_bf16_nf4(
float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
) {
quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_bf16(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16_fp4(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16_nf4(
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
) {
dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream);
}
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_grad_##gbits( \
gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, \
const int n \
) { \
name##32bit_grad_##gbits( \
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \
lr, gnorm_scale, skip_zeros, n \
); \
}
MAKE_CFUNC32(adam, float, fp32)
MAKE_CFUNC32(adam, half, fp16)
MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
MAKE_CFUNC32(momentum, float, 32)
MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, fp32)
MAKE_CFUNC32(lion, half, fp16)
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16)
MAKE_CFUNC32(ademamix, float, fp32)
MAKE_CFUNC32(ademamix, half, fp16)
MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
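// The c-prefixed variants generated by these macros simply forward to the wrappers defined above,
// giving each optimizer entry point a symbol with C linkage.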
#define MAKE_CFUNC8(name, gtype, gbits) \
void c##name##_static_8bit_grad_##gbits( \
gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
float gnorm_scale, int n \
) { \
name##_static_8bit_grad_##gbits( \
g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \
max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \
); \
}
MAKE_CFUNC8(adam, float, 32)
MAKE_CFUNC8(adam, half, 16)
MAKE_CFUNC8(momentum, float, 32)
MAKE_CFUNC8(momentum, half, 16)
MAKE_CFUNC8(rmsprop, float, 32)
MAKE_CFUNC8(rmsprop, half, 16)
MAKE_CFUNC8(lion, float, 32)
MAKE_CFUNC8(lion, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_grad_##gbits( \
gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
) { \
fname##_8bit_blockwise_grad_##gbits( \
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \
weight_decay, gnorm_scale, skip_zeros, n \
); \
}
MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
void cpercentile_clipping_g32(float* g, float* gnorm_vec, int step, const int n) {
percentileClipping_g32(g, gnorm_vec, step, n);
}
void cpercentile_clipping_g16(half* g, float* gnorm_vec, int step, const int n) {
percentileClipping_g16(g, gnorm_vec, step, n);
}
void cigemm(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc
) {
gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc);
}
void cbatched_igemm(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc, long strideA, long strideB, long strideC, int batchCount
) {
strided_gemmex(
context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount
);
}
Context* get_context() { return new Context(); }
ContextCusparse* get_cusparse() { return new ContextCusparse(); }
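// get_context()/get_cusparse() allocate the handle-carrying Context/ContextCusparse objects; the
// cigemmlt_* and cspmm_coo wrappers below cast context->m_handle to the cuBLASLt/cuSPARSE handle they need.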
int cigemmlt_32(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_32((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int cigemmlt_8(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_8((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
int cigemmlt_8_rowscale(
Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
int ldb, int ldc, cudaStream_t stream
) {
return igemmlt_8_rowscale((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
void cdequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
) {
dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream);
}
void cget_row_stats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
getRowStats(A, rowStats, threshold, rows, cols, stream);
}
void cint8_vector_quant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
) {
int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
}
void cspmm_coo(
ContextCusparse* context, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
) {
spmm_coo(
(cusparseHandle_t)context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C,
transposed_B
);
}
void cspmm_coo_very_sparse_naive_fp16(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive_fp16(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
void cspmm_coo_very_sparse_naive_int8(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
) {
spmm_coo_very_sparse_naive_int8(
max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
colsB
);
}
// void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
void cgemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc);
}
void cgemm_4bit_inference(
int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
) {
gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
void* cget_managed_ptr(size_t bytes) {
void* ptr;
CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
CUDA_CHECK_RETURN(cudaPeekAtLastError());
return ptr;
}
void cprefetch(void* ptr, size_t bytes, int device) {
int hasPrefetch = 0;
CUDA_CHECK_RETURN(
cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)
); // 40ns overhead
if (hasPrefetch == 0)
return;
CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void c##fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { fname##_##type_name(A, B, value, n); }
CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
void cgemm_4bit_inference_naive_fp16(
int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void cgemm_4bit_inference_naive_bf16(
int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void cgemm_4bit_inference_naive_fp32(
int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
) {
gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
#endif
void cquantize_blockwise_cpu_fp32(
float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n
) {
quantize_cpu(code, A, absmax, out, blocksize, n);
}
void cdequantize_blockwise_cpu_fp32(
float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n
) {
dequantize_cpu(code, A, absmax, out, blocksize, n);
}
}
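// Usage sketch (not part of the diff above; names and sizes are illustrative): a host-side caller
// can drive the CPU blockwise path exported by this file. It assumes `code` holds the 256-entry
// quantization map that quantize_cpu/dequantize_cpu expect and that one absmax value is produced
// per block.
//
//   #include <vector>
//
//   int main() {
//       const long long n = 4096, blocksize = 256;
//       std::vector<float> code(256), A(n), out(n), absmax(n / blocksize);
//       std::vector<unsigned char> q(n);
//       // ... fill `code` with the quantization map and `A` with input data ...
//       cquantize_blockwise_cpu_fp32(code.data(), A.data(), absmax.data(), q.data(), blocksize, n);
//       cdequantize_blockwise_cpu_fp32(code.data(), q.data(), absmax.data(), out.data(), blocksize, n);
//       return 0;
//   }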