Unverified commit fe46dac2, authored by AllentDan, committed by GitHub

Add lint action (#32)

* temp

* fix lint

* csrc->src

* remove clang-format

* skip .rst

* skip doc

* clang-format

version

version

* mat_B
parent e8ab4ba3
@@ -509,10 +509,10 @@ void cublasINT8MMWrapper::SpGemm(
    }
    else {
        // initializing MatDesc takes a lot of time
-       cusparseLtMatDescriptor_t matA, matB, matC;
-       sp_mat_A_desc_map_[mark] = matA;
-       sp_mat_B_desc_map_[mark] = matB;
-       sp_mat_C_desc_map_[mark] = matC;
+       cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
+       sp_mat_A_desc_map_[mark] = mat_A;
+       sp_mat_B_desc_map_[mark] = mat_B;
+       sp_mat_C_desc_map_[mark] = mat_C;
        CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
                                                          &sp_mat_A_desc_map_[mark],
                                                          num_A_rows,
......
@@ -695,10 +695,10 @@ void cublasMMWrapper::SpGemm(cublasOperation_t transa,
    }
    else {
        // initializing MatDesc takes a lot of time
-       cusparseLtMatDescriptor_t matA, matB, matC;
-       sp_mat_A_desc_map_[mark] = matA;
-       sp_mat_B_desc_map_[mark] = matB;
-       sp_mat_C_desc_map_[mark] = matC;
+       cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
+       sp_mat_A_desc_map_[mark] = mat_A;
+       sp_mat_B_desc_map_[mark] = mat_B;
+       sp_mat_C_desc_map_[mark] = mat_C;
        CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
                                                          &sp_mat_A_desc_map_[mark],
                                                          num_A_rows,
@@ -752,9 +752,9 @@ size_t cublasMMWrapper::getSparseMatrixSize(int m, int k)
    int num_A_cols = k;
    int lda = num_A_rows;
-   cusparseLtMatDescriptor_t matA;
+   cusparseLtMatDescriptor_t mat_A;
    CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
-                                                     &matA,
+                                                     &mat_A,
                                                      num_A_rows,
                                                      num_A_cols,
                                                      lda,
@@ -763,7 +763,7 @@ size_t cublasMMWrapper::getSparseMatrixSize(int m, int k)
                                                      order,
                                                      CUSPARSELT_SPARSITY_50_PERCENT));
    size_t compressed_size = 0;
-   CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&cusparselt_handle_, &matA, &compressed_size));
+   CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&cusparselt_handle_, &mat_A, &compressed_size));
    return compressed_size;
}
@@ -771,11 +771,11 @@ void cublasMMWrapper::compressMatrix(const void* input, void* output, const int
{
    cusparseOrder_t order = CUSPARSE_ORDER_COL;
    cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
-   cusparseLtMatDescriptor_t matA;
+   cusparseLtMatDescriptor_t mat_A;
    unsigned alignment = 16;
    CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-       &cusparselt_handle_, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
-   CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &matA, true, opA, input, output, stream_))
+       &cusparselt_handle_, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
+   CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &mat_A, true, opA, input, output, stream_))
    sync_check_cuda_error();
}
......
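The two wrapper functions above share one cuSPARSELt pattern: build a structured (2:4 sparse) descriptor for A, ask how large the compressed form will be, then compress. Below is a minimal standalone sketch of that flow using only calls that appear in this commit; the helper name and the error-checking macro are assumptions for illustration, not the repository's exact code.

// Hedged sketch of the size-query + compress flow used by getSparseMatrixSize and compressMatrix above.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cusparseLt.h>

#define CHECK_CUSPARSE(call)                                                                 \
    do {                                                                                     \
        cusparseStatus_t status_ = (call);                                                   \
        if (status_ != CUSPARSE_STATUS_SUCCESS) {                                            \
            std::printf("cuSPARSELt error %d at %s:%d\n", (int)status_, __FILE__, __LINE__); \
        }                                                                                    \
    } while (0)

// Compresses a column-major FP16 matrix A (m x k, lda = m) that already satisfies the
// 2:4 sparsity pattern; the compressed buffer is allocated here and returned to the caller.
void compressSparseA(cusparseLtHandle_t* handle, const __half* d_A, void** d_A_compressed, int m, int k, cudaStream_t stream)
{
    const unsigned            alignment = 16;
    const cusparseOrder_t     order     = CUSPARSE_ORDER_COL;
    const cusparseOperation_t opA       = CUSPARSE_OPERATION_NON_TRANSPOSE;

    // Describe A as 50% (2:4) structured-sparse, as both wrappers above do.
    cusparseLtMatDescriptor_t mat_A;
    CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
        handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT));

    // Query the compressed size, allocate, and compress into the new buffer.
    size_t compressed_size = 0;
    CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(handle, &mat_A, &compressed_size));
    cudaMalloc(d_A_compressed, compressed_size);
    CHECK_CUSPARSE(cusparseLtSpMMACompress2(handle, &mat_A, true, opA, d_A, *d_A_compressed, stream));
}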
@@ -22,7 +22,8 @@
namespace fastertransformer {

#ifdef ENABLE_BF16
-inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = __low2float(val);
@@ -33,26 +34,34 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#endif
}

-inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
+inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);
-   union { int8_t int8[2]; int16_t int16; };
+   union {
+       int8_t int8[2];
+       int16_t int16;
+   };
    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
    return int16;
#else
    val = __hmin2(val, make_bfloat162(127., 127.));
    val = __hmax2(val, make_bfloat162(-128., -128.));
-   union { int8_t int8[2]; int16_t int16; };
+   union {
+       int8_t int8[2];
+       int16_t int16;
+   };
    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
    return int16;
#endif
}
-inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
+inline __device__ __nv_bfloat162 float22bf162(const float2 val)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __floats2bfloat162_rn(val.x, val.y);
#else
@@ -60,7 +69,8 @@ inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#endif
}

-inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    __nv_bfloat162 val2;
    val2.x = val;
@@ -71,7 +81,8 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#endif
}

-inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
+inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
@@ -84,15 +95,17 @@ inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bf
#endif
}

-inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
+inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-   return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
+   return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
    return __hadd(x, y);
#endif
}

-inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
+inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
@@ -105,15 +118,17 @@ inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bf
#endif
}

-inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
+inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-   return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
+   return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
    return __hsub(x, y);
#endif
}
-inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
+inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
@@ -126,15 +141,17 @@ inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bf
#endif
}

-inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
+inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-   return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
+   return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
    return __hmul(x, y);
#endif
}

-inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
+inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh, fzl, fzh;
    fxl = __low2float(x);
@@ -149,19 +166,22 @@ inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bf
#endif
}

-inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
+inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-   return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
+   return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
    return __hfma(x, y, z);
#endif
}

-inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
+inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh;
    fxl = __low2float(x);
-   fxh = __high2float(x);;
+   fxh = __high2float(x);
+   ;
    return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
    return h2exp(x);
@@ -169,17 +189,27 @@ inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
-inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
-inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
+inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
+{
+   return bf16hmul2(x, y);
+};
+inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
+{
+   return bf16hadd2(x, y);
+};

inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
-   __nv_bfloat162 t; t.x = x; t.y = y; return t;
+   __nv_bfloat162 t;
+   t.x = x;
+   t.y = y;
+   return t;
}
#endif

-inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
+inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
@@ -187,7 +217,8 @@ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}

-inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
+inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
@@ -195,7 +226,8 @@ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}

-inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
@@ -210,7 +242,8 @@ inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, _
#endif
}

-inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
+inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
@@ -218,7 +251,8 @@ inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}

-inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
@@ -233,7 +267,8 @@ inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, _
#endif
}

-inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
+inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
+{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
    fal = __low2float(a);
......
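Every helper in this header follows the same fallback pattern: below compute capability 8.0 (__CUDA_ARCH__ < 800) the bf16 operands are widened to float, computed, and rounded back, while SM80+ uses the native bf16 intrinsic. A minimal self-contained illustration of that pattern follows; the function name is invented for the example and is not part of the header.

#include <cuda_bf16.h>

// Illustrative only: same shape as the bf16hadd2-style helpers above.
// Pre-SM80 devices fall back to float math; SM80+ uses the native packed intrinsic.
inline __device__ __nv_bfloat162 bf16_add2_example(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    // Widen each half to float, add, and round the pair back to bf16.
    float fxl = __low2float(x);
    float fxh = __high2float(x);
    float fyl = __low2float(y);
    float fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
    // Native packed bf16 add, available on compute capability 8.0 and newer.
    return __hadd2(x, y);
#endif
}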
@@ -462,29 +462,29 @@ void generate_encoder_gemm_config(
        T* d_C = d_B + k * n * batchCount[i];
        T* dA_compressed;
        {
-           cusparseLtMatDescriptor_t matA;
+           cusparseLtMatDescriptor_t mat_A;
            CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-               &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
+               &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
            CHECK_CUSPARSE(
-               cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
+               cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
            size_t compressed_size;
-           CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size))
+           CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
            check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
-           CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream))
+           CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
        }
        float exec_time = 99999.0f;
        int fast_algo = 0;
        for (int alg = 0; alg < 4; ++alg) {
            cudaDeviceSynchronize();
-           cusparseLtMatDescriptor_t matA, matB, matC;
+           cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
            void* d_workspace = nullptr;
            int num_streams = 1;
            cudaStream_t streams[1] = {stream};
            CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-               &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
-           CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order))
-           CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_16F, order))
+               &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
+           CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
+           CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
            gettimeofday(&start, NULL);
            for (int ite = 0; ite < ites; ++ite) {
                // initializing MatDesc takes a lot of time
@@ -494,7 +494,7 @@ void generate_encoder_gemm_config(
                cusparseLtMatmulAlgSelection_t alg_sel;
                cusparseLtMatmulPlan_t plan;
                CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
-                   &handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type))
+                   &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
                CHECK_CUSPARSE(
                    cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
                CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
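The benchmark loops in the encoder, encoder-int8, GPT, and T5 generators all time the same cuSPARSELt setup for algorithm ids 0 through 3. Below is a condensed, hedged sketch of that per-candidate setup; the helper name and parameter list are assumptions for illustration, error checking is omitted, and plan creation plus the timed cusparseLtMatmul call would follow as in the loops above.

// Hedged sketch: descriptor and algorithm-selection setup for one candidate algorithm id.
#include <cusparseLt.h>

void selectCandidateAlgorithm(cusparseLtHandle_t* handle, int m, int n, int k, int alg, cusparseComputeType compute_type)
{
    const unsigned        alignment = 16;
    const cusparseOrder_t order     = CUSPARSE_ORDER_COL;

    // A is 2:4 structured-sparse; B and C are dense FP16 in column-major order.
    cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
    cusparseLtStructuredDescriptorInit(
        handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT);
    cusparseLtDenseDescriptorInit(handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order);
    cusparseLtDenseDescriptorInit(handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order);

    // Describe C = A * B (C doubles as the output D) and start from the default algorithm.
    cusparseLtMatmulDescriptor_t   matmul;
    cusparseLtMatmulAlgSelection_t alg_sel;
    cusparseLtMatmulDescriptorInit(handle,
                                   &matmul,
                                   CUSPARSE_OPERATION_NON_TRANSPOSE,
                                   CUSPARSE_OPERATION_NON_TRANSPOSE,
                                   &mat_A,
                                   &mat_B,
                                   &mat_C,
                                   &mat_C,
                                   compute_type);
    cusparseLtMatmulAlgSelectionInit(handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT);

    // Pin the candidate algorithm id so this configuration can be timed in isolation.
    cusparseLtMatmulAlgSetAttribute(handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg));
}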
@@ -1239,15 +1239,15 @@ int generate_encoder_igemm_config(
        int8_t* d_C = d_B + k * n;
        int8_t* dA_compressed;
        {
-           cusparseLtMatDescriptor_t matA;
+           cusparseLtMatDescriptor_t mat_A;
            CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-               &handle, &matA, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
+               &handle, &mat_A, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
            CHECK_CUSPARSE(
-               cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
+               cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
            size_t compressed_size;
-           CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size))
+           CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
            check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
-           CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream))
+           CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
        }
        cudaDeviceSynchronize();
        cudaError_t result = cudaGetLastError();
@@ -1259,14 +1259,14 @@ int generate_encoder_igemm_config(
        int fast_algo = 0;
        for (int alg = 0; alg < 4; ++alg) {
            cudaDeviceSynchronize();
-           cusparseLtMatDescriptor_t matA, matB, matC;
+           cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
            void* d_workspace = nullptr;
            int num_streams = 1;
            cudaStream_t streams[1] = {stream};
            CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-               &handle, &matA, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
-           CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_8I, col_order))
-           CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_8I, col_order))
+               &handle, &mat_A, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
+           CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_8I, col_order))
+           CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_8I, col_order))
            gettimeofday(&start, NULL);
            for (int ite = 0; ite < ites; ++ite) {
                // initializing MatDesc takes a lot of time
@@ -1276,7 +1276,7 @@ int generate_encoder_igemm_config(
                cusparseLtMatmulAlgSelection_t alg_sel;
                cusparseLtMatmulPlan_t plan;
                CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
-                   &handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type))
+                   &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
                CHECK_CUSPARSE(
                    cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
                CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
@@ -617,15 +617,15 @@ void generate_gpt_gemm_config(int batch_size,
        T* d_C = d_B + k * n * batchCount[i];
        T* dA_compressed;
        {
-           cusparseLtMatDescriptor_t matA;
+           cusparseLtMatDescriptor_t mat_A;
            CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-               &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
+               &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
            CHECK_CUSPARSE(
-               cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
+               cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
            size_t compressed_size;
-           CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size))
+           CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
            check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
-           CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream))
+           CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
        }
        float exec_time = 99999.0f;
@@ -633,14 +633,15 @@ void generate_gpt_gemm_config(int batch_size,
        if (isSparseGemmAvailable(m, n, k)) {
            for (int alg = 0; alg < 4; ++alg) {
                cudaDeviceSynchronize();
-               cusparseLtMatDescriptor_t matA, matB, matC;
+               cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
                void* d_workspace = nullptr;
                int num_streams = 1;
                cudaStream_t streams[1] = {stream};
                CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-                   &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
-               CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order))
-               CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_16F, order))
+                   &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
+               CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
+               CHECK_CUSPARSE(
+                   cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
                cudaDeviceSynchronize();
                gettimeofday(&start, NULL);
                for (int ite = 0; ite < ites; ++ite) {
@@ -651,7 +652,7 @@ void generate_gpt_gemm_config(int batch_size,
                    cusparseLtMatmulAlgSelection_t alg_sel;
                    cusparseLtMatmulPlan_t plan;
                    CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
-                       &handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type))
+                       &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
                    CHECK_CUSPARSE(
                        cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
                    CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
@@ -616,15 +616,15 @@ void generate_t5_gemm_config(int batch_size,
        T* d_C = d_B + k * n * batchCount[i];
        T* dA_compressed;
        {
-           cusparseLtMatDescriptor_t matA;
+           cusparseLtMatDescriptor_t mat_A;
            CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-               &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
+               &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
            CHECK_CUSPARSE(
-               cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
+               cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
            size_t compressed_size;
-           CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size))
+           CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
            check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
-           CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream))
+           CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
        }
        float exec_time = 99999.0f;
@@ -632,14 +632,15 @@ void generate_t5_gemm_config(int batch_size,
        if (isSparseGemmAvailable(m, n, k)) {
            for (int alg = 0; alg < 4; ++alg) {
                cudaDeviceSynchronize();
-               cusparseLtMatDescriptor_t matA, matB, matC;
+               cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
                void* d_workspace = nullptr;
                int num_streams = 1;
                cudaStream_t streams[1] = {stream};
                CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
-                   &handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
-               CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order))
-               CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_16F, order))
+                   &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
+               CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
+               CHECK_CUSPARSE(
+                   cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
                cudaDeviceSynchronize();
                gettimeofday(&start, NULL);
                for (int ite = 0; ite < ites; ++ite) {
@@ -650,7 +651,7 @@ void generate_t5_gemm_config(int batch_size,
                    cusparseLtMatmulAlgSelection_t alg_sel;
                    cusparseLtMatmulPlan_t plan;
                    CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
-                       &handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type))
+                       &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
                    CHECK_CUSPARSE(
                        cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
                    CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
@@ -27,8 +27,7 @@ namespace fastertransformer {
class Logger {
public:
-   enum Level
-   {
+   enum Level {
        TRACE = 0,
        DEBUG = 10,
        INFO = 20,
......
-import torch
+# flake8: noqa
import unittest
+import torch

def random_tensor(shape, dtype, device, mean=0, std=1):
    return torch.empty(shape, dtype=dtype, device=device).normal_(mean, std)

class TestGemmDequantize(unittest.TestCase):

    def setUp(self) -> None:
-       torch.classes.load_library("lib/libth_transformer.so")
-       torch.classes.load_library("lib/libgemm_dq_unit_ops.so")
+       torch.classes.load_library('lib/libth_transformer.so')
+       torch.classes.load_library('lib/libgemm_dq_unit_ops.so')
        self.unpack_packed_int4s = torch.ops.fastertransformer.unpack_int4_packed_tensor_to_int8
        self.pack_int4s = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
        self.fused_gemm_dq = torch.ops.gemm_dq_unit_ops.fused_gemm_dq
@@ -26,17 +31,26 @@ class TestGemmDequantize(unittest.TestCase):
        upper_bound = 127 if quant_type == torch.int8 else 7
        m, n, k = 64, 128, 64
-       weights = torch.randint(lower_bound, upper_bound, [k, n], dtype=torch.int8, device="cpu")
-       packed_weight = self.pack_int4s(weights) if quant_type == torch.quint4x2 else weights
-       cuda_weights = self.preprocess_weights_for_mixed_gemm(packed_weight, quant_type).to("cuda")
-       weights = weights.to("cuda")
-       act = torch.eye(m, dtype=weight_type, device="cuda")
+       weights = torch.randint(lower_bound,
+                               upper_bound, [k, n],
+                               dtype=torch.int8,
+                               device='cpu')
+       packed_weight = self.pack_int4s(
+           weights) if quant_type == torch.quint4x2 else weights
+       cuda_weights = self.preprocess_weights_for_mixed_gemm(
+           packed_weight, quant_type).to('cuda')
+       weights = weights.to('cuda')
+       act = torch.eye(m, dtype=weight_type, device='cuda')
        scales = torch.ones([n], dtype=weight_type, device='cuda')
        actual = self.fused_gemm_dq(act, cuda_weights, scales)
-       torch.testing.assert_close(actual, weights, atol=0, rtol=0, check_dtype=False)
+       torch.testing.assert_close(actual,
+                                  weights,
+                                  atol=0,
+                                  rtol=0,
+                                  check_dtype=False)

    def test_fp16_int8_dequantize(self):
        self.dequantize_test_helper(torch.float16, torch.int8)
@@ -51,144 +65,207 @@ class TestGemmDequantize(unittest.TestCase):
        self.dequantize_test_helper(torch.bfloat16, torch.quint4x2)

    def apply_act(self, inp, act_str):
-       if act_str == "identity":
+       if act_str == 'identity':
            return inp
-       elif act_str == "silu":
+       elif act_str == 'silu':
            return torch.nn.SiLU()(inp)
-       elif act_str == "relu":
+       elif act_str == 'relu':
            return torch.nn.ReLU()(inp)
-       elif act_str == "gelu":
-           return torch.nn.GELU(approximate="tanh")(inp)
+       elif act_str == 'gelu':
+           return torch.nn.GELU(approximate='tanh')(inp)
        else:
-           assert False, "Unsupported activation"
+           assert False, 'Unsupported activation'

-   def gemm_dequant_test_helper(self, compute_type, weight_dtype, gemm_ms, gemm_ns, gemm_ks, rtol, atol, act_str="only_gemm", benchmark=False):
-       assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, "Weight must be quantized"
+   def gemm_dequant_test_helper(self,
+                                compute_type,
+                                weight_dtype,
+                                gemm_ms,
+                                gemm_ns,
+                                gemm_ks,
+                                rtol,
+                                atol,
+                                act_str='only_gemm',
+                                benchmark=False):
+       assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, 'Weight must be quantized'
        for gemm_k in gemm_ks:
            for gemm_n in gemm_ns:
-               torch_weights_cpu = random_tensor((gemm_k, gemm_n), dtype=compute_type, device="cpu", mean=0, std=0.002)
-               ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(torch_weights_cpu, weight_dtype)
-               ref_torch_weights = self.unpack_packed_int4s(ref_torch_weights) if weight_dtype == torch.quint4x2 else ref_torch_weights
-               ref_torch_weights = ref_torch_weights.to("cuda")
-               processed_torch_weights = processed_torch_weights.to("cuda")
-               torch_weight_scales = torch_weight_scales.to("cuda")
-               torch_biases = random_tensor((gemm_n), dtype=compute_type, device="cuda", mean=0, std=0.1)
+               torch_weights_cpu = random_tensor((gemm_k, gemm_n),
+                                                 dtype=compute_type,
+                                                 device='cpu',
+                                                 mean=0,
+                                                 std=0.002)
+               ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(
+                   torch_weights_cpu, weight_dtype)
+               ref_torch_weights = self.unpack_packed_int4s(
+                   ref_torch_weights
+               ) if weight_dtype == torch.quint4x2 else ref_torch_weights
+               ref_torch_weights = ref_torch_weights.to('cuda')
+               processed_torch_weights = processed_torch_weights.to('cuda')
+               torch_weight_scales = torch_weight_scales.to('cuda')
+               torch_biases = random_tensor((gemm_n),
+                                            dtype=compute_type,
+                                            device='cuda',
+                                            mean=0,
+                                            std=0.1)
                for num_rows in gemm_ms:
-                   torch_activations = torch.randn(size=(num_rows, gemm_k), dtype=compute_type, device="cuda")
+                   torch_activations = torch.randn(size=(num_rows, gemm_k),
+                                                   dtype=compute_type,
+                                                   device='cuda')
                    scales_unsqueezed = torch_weight_scales.unsqueeze(0)
-                   casted_weights = ref_torch_weights.to(torch_activations.dtype)
-                   dequantized_weights = torch.multiply(casted_weights, scales_unsqueezed)
+                   casted_weights = ref_torch_weights.to(
+                       torch_activations.dtype)
+                   dequantized_weights = torch.multiply(
+                       casted_weights, scales_unsqueezed)
                    if benchmark:
-                       assert act_str == "only_gemm", "Benchmarks against cublas must use just GEMM."
+                       assert act_str == 'only_gemm', 'Benchmarks against cublas must use just GEMM.'
                        torch.cuda.profiler.start()
-                       times, results = self.bench(torch_activations, processed_torch_weights, torch_weight_scales, dequantized_weights, 200)
+                       times, results = self.bench(torch_activations,
+                                                   processed_torch_weights,
+                                                   torch_weight_scales,
+                                                   dequantized_weights, 200)
                        torch.cuda.profiler.stop()
                        times = times[0]
                        cublas_time = times[0].item()
                        ft_time = times[1].item()
                        ft_speedup = cublas_time / ft_time
-                       print("{},{},{},{},{},{}".format(num_rows, gemm_n, gemm_k, cublas_time, ft_time, ft_speedup))
+                       print('{},{},{},{},{},{}'.format(
+                           num_rows, gemm_n, gemm_k, cublas_time, ft_time,
+                           ft_speedup))
                        reference_result = results[0]
                        ft_result = results[1]
                    else:
-                       if act_str == "only_gemm":
-                           reference_result = torch.matmul(torch_activations, dequantized_weights)
-                           ft_result = self.fused_gemm_dq(torch_activations, processed_torch_weights, torch_weight_scales)
+                       if act_str == 'only_gemm':
+                           reference_result = torch.matmul(
+                               torch_activations, dequantized_weights)
+                           ft_result = self.fused_gemm_dq(
+                               torch_activations, processed_torch_weights,
+                               torch_weight_scales)
                        else:
-                           reference_result = torch.matmul(torch_activations, dequantized_weights)
+                           reference_result = torch.matmul(
+                               torch_activations, dequantized_weights)
                            reference_result += torch_biases.unsqueeze(0)
-                           reference_result = self.apply_act(reference_result, act_str)
-                           ft_result = self.fused_gemm_dq_bias_act(torch_activations, processed_torch_weights, torch_weight_scales, torch_biases, act_str)
+                           reference_result = self.apply_act(
+                               reference_result, act_str)
+                           ft_result = self.fused_gemm_dq_bias_act(
+                               torch_activations, processed_torch_weights,
+                               torch_weight_scales, torch_biases, act_str)

-                   msg = "FC1 Failed on m={}, n={}, k={}".format(num_rows, gemm_n, gemm_k)
-                   torch.testing.assert_close(ft_result, reference_result, rtol=rtol, atol=atol, msg=msg, check_dtype=False)
+                   msg = 'FC1 Failed on m={}, n={}, k={}'.format(
+                       num_rows, gemm_n, gemm_k)
+                   torch.testing.assert_close(ft_result,
+                                              reference_result,
+                                              rtol=rtol,
+                                              atol=atol,
+                                              msg=msg,
+                                              check_dtype=False)
    def test_fp16_int8_gemm(self):
-       self.gemm_dequant_test_helper(torch.float16, torch.int8,
-                                     gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
-                                     gemm_ns = [1024, 2048, 4096],
-                                     gemm_ks = [4096, 8192, 16384],
-                                     rtol=0.001, atol=0.002)
+       self.gemm_dequant_test_helper(
+           torch.float16,
+           torch.int8,
+           gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
+           gemm_ns=[1024, 2048, 4096],
+           gemm_ks=[4096, 8192, 16384],
+           rtol=0.001,
+           atol=0.002)

    def test_fp16_int4_gemm(self):
-       self.gemm_dequant_test_helper(torch.float16, torch.quint4x2,
-                                     gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
-                                     gemm_ns = [1024, 2048, 4096],
-                                     gemm_ks = [4096, 8192, 16384],
-                                     rtol=0.001, atol=0.002)
+       self.gemm_dequant_test_helper(
+           torch.float16,
+           torch.quint4x2,
+           gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
+           gemm_ns=[1024, 2048, 4096],
+           gemm_ks=[4096, 8192, 16384],
+           rtol=0.001,
+           atol=0.002)

    def test_bf16_int8_gemm(self):
-       self.gemm_dequant_test_helper(torch.bfloat16, torch.int8,
-                                     gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
-                                     gemm_ns = [1024, 2048, 4096],
-                                     gemm_ks = [4096, 8192, 16384],
-                                     rtol=0.01, atol=0.01)
+       self.gemm_dequant_test_helper(
+           torch.bfloat16,
+           torch.int8,
+           gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
+           gemm_ns=[1024, 2048, 4096],
+           gemm_ks=[4096, 8192, 16384],
+           rtol=0.01,
+           atol=0.01)

    def test_bf16_int4_gemm(self):
-       self.gemm_dequant_test_helper(torch.bfloat16, torch.quint4x2,
-                                     gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
-                                     gemm_ns = [1024, 2048, 4096],
-                                     gemm_ks = [4096, 8192, 16384],
-                                     rtol=0.01, atol=0.01)
+       self.gemm_dequant_test_helper(
+           torch.bfloat16,
+           torch.quint4x2,
+           gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
+           gemm_ns=[1024, 2048, 4096],
+           gemm_ks=[4096, 8192, 16384],
+           rtol=0.01,
+           atol=0.01)
    def test_fp16_int8_gemm_bias(self):
-       self.gemm_dequant_test_helper(torch.float16, torch.int8,
-                                     gemm_ms = [256],
-                                     gemm_ns = [1024],
-                                     gemm_ks = [8192],
-                                     rtol=0.001, atol=0.002,
-                                     act_str="identity")
+       self.gemm_dequant_test_helper(torch.float16,
+                                     torch.int8,
+                                     gemm_ms=[256],
+                                     gemm_ns=[1024],
+                                     gemm_ks=[8192],
+                                     rtol=0.001,
+                                     atol=0.002,
+                                     act_str='identity')

    def test_fp16_int8_gemm_bias_relu(self):
-       self.gemm_dequant_test_helper(torch.float16, torch.int8,
-                                     gemm_ms = [256],
-                                     gemm_ns = [1024],
-                                     gemm_ks = [8192],
-                                     rtol=0.001, atol=0.002,
-                                     act_str="relu")
+       self.gemm_dequant_test_helper(torch.float16,
+                                     torch.int8,
+                                     gemm_ms=[256],
+                                     gemm_ns=[1024],
+                                     gemm_ks=[8192],
+                                     rtol=0.001,
+                                     atol=0.002,
+                                     act_str='relu')

    def test_fp16_int8_gemm_bias_gelu(self):
-       self.gemm_dequant_test_helper(torch.float16, torch.int8,
-                                     gemm_ms = [256],
-                                     gemm_ns = [1024],
-                                     gemm_ks = [8192],
-                                     rtol=0.001, atol=0.002,
-                                     act_str="gelu")
+       self.gemm_dequant_test_helper(torch.float16,
+                                     torch.int8,
+                                     gemm_ms=[256],
+                                     gemm_ns=[1024],
+                                     gemm_ks=[8192],
+                                     rtol=0.001,
+                                     atol=0.002,
+                                     act_str='gelu')

    def test_fp16_int8_gemm_bias_silu(self):
-       self.gemm_dequant_test_helper(torch.float16, torch.int8,
-                                     gemm_ms = [256],
-                                     gemm_ns = [1024],
-                                     gemm_ks = [8192],
-                                     rtol=0.001, atol=0.002,
-                                     act_str="silu")
+       self.gemm_dequant_test_helper(torch.float16,
+                                     torch.int8,
+                                     gemm_ms=[256],
+                                     gemm_ns=[1024],
+                                     gemm_ks=[8192],
+                                     rtol=0.001,
+                                     atol=0.002,
+                                     act_str='silu')
    def bench_helper(self, act_type, quant_type, rtol, atol):
        # Warm, using bfloat here since it seems to reliably use cublas.
-       x = random_tensor([20480, 20480], torch.bfloat16, device="cuda")
+       x = random_tensor([20480, 20480], torch.bfloat16, device='cuda')
        warm_iters = 30
        for iter in range(warm_iters):
            res = x @ x

        m_shapes = torch.arange(0, 12)
-       m_shapes = 2 ** m_shapes
+       m_shapes = 2**m_shapes

-       self.gemm_dequant_test_helper(act_type, quant_type,
-                                     gemm_ms = [128],
-                                     gemm_ns = [1536],
-                                     gemm_ks = [12288],
-                                     rtol=rtol, atol=atol, benchmark=True)
+       self.gemm_dequant_test_helper(act_type,
+                                     quant_type,
+                                     gemm_ms=[128],
+                                     gemm_ns=[1536],
+                                     gemm_ks=[12288],
+                                     rtol=rtol,
+                                     atol=atol,
+                                     benchmark=True)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_fp16_int8_cublas(self):
        self.bench_helper(torch.float16, torch.int8, 1e-3, 0.002)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_bf16_int8_cublas(self):
        self.bench_helper(torch.bfloat16, torch.int8, 1e-2, 1e-2)
@@ -197,10 +274,10 @@ class TestGemmDequantize(unittest.TestCase):
    def test_fp16_int4_cublas(self):
        self.bench_helper(torch.float16, torch.quint4x2, 1e-3, 0.002)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_bf16_int4_cublas(self):
        self.bench_helper(torch.bfloat16, torch.quint4x2, 1e-2, 1e-2)

if __name__ == '__main__':
    unittest.main()