Unverified Commit fe46dac2 authored by AllentDan's avatar AllentDan Committed by GitHub
Browse files

Add lint action (#32)

* temp

* fix lint

* csrc->src

* remove clang-format

* skip .rst

* skip doc

* clang-format

version

version

* mat_B
parent e8ab4ba3
......@@ -509,10 +509,10 @@ void cublasINT8MMWrapper::SpGemm(
}
else {
// initializing MatDesc takes a lot of time
cusparseLtMatDescriptor_t matA, matB, matC;
sp_mat_A_desc_map_[mark] = matA;
sp_mat_B_desc_map_[mark] = matB;
sp_mat_C_desc_map_[mark] = matC;
cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
sp_mat_A_desc_map_[mark] = mat_A;
sp_mat_B_desc_map_[mark] = mat_B;
sp_mat_C_desc_map_[mark] = mat_C;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
&sp_mat_A_desc_map_[mark],
num_A_rows,
......
......@@ -695,10 +695,10 @@ void cublasMMWrapper::SpGemm(cublasOperation_t transa,
}
else {
// initializing MatDesc takes a lot of time
cusparseLtMatDescriptor_t matA, matB, matC;
sp_mat_A_desc_map_[mark] = matA;
sp_mat_B_desc_map_[mark] = matB;
sp_mat_C_desc_map_[mark] = matC;
cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
sp_mat_A_desc_map_[mark] = mat_A;
sp_mat_B_desc_map_[mark] = mat_B;
sp_mat_C_desc_map_[mark] = mat_C;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
&sp_mat_A_desc_map_[mark],
num_A_rows,
......@@ -752,9 +752,9 @@ size_t cublasMMWrapper::getSparseMatrixSize(int m, int k)
int num_A_cols = k;
int lda = num_A_rows;
cusparseLtMatDescriptor_t matA;
cusparseLtMatDescriptor_t mat_A;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
&matA,
&mat_A,
num_A_rows,
num_A_cols,
lda,
......@@ -763,7 +763,7 @@ size_t cublasMMWrapper::getSparseMatrixSize(int m, int k)
order,
CUSPARSELT_SPARSITY_50_PERCENT));
size_t compressed_size = 0;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&cusparselt_handle_, &matA, &compressed_size));
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&cusparselt_handle_, &mat_A, &compressed_size));
return compressed_size;
}
......@@ -771,11 +771,11 @@ void cublasMMWrapper::compressMatrix(const void* input, void* output, const int
{
cusparseOrder_t order = CUSPARSE_ORDER_COL;
cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
cusparseLtMatDescriptor_t matA;
cusparseLtMatDescriptor_t mat_A;
unsigned alignment = 16;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&cusparselt_handle_, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &matA, true, opA, input, output, stream_))
&cusparselt_handle_, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &mat_A, true, opA, input, output, stream_))
sync_check_cuda_error();
}
......
......@@ -22,7 +22,8 @@
namespace fastertransformer {
#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
......@@ -33,26 +34,34 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union { int8_t int8[2]; int16_t int16; };
union {
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union { int8_t int8[2]; int16_t int16; };
union {
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
......@@ -60,7 +69,8 @@ inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
......@@ -71,7 +81,8 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
......@@ -84,15 +95,17 @@ inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bf
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
......@@ -105,15 +118,17 @@ inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bf
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
......@@ -126,15 +141,17 @@ inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bf
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
......@@ -149,19 +166,22 @@ inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bf
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);;
fxh = __high2float(x);
;
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
return h2exp(x);
......@@ -169,17 +189,27 @@ inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return bf16hmul2(x, y);
};
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return bf16hadd2(x, y);
};
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
__nv_bfloat162 t; t.x = x; t.y = y; return t;
__nv_bfloat162 t;
t.x = x;
t.y = y;
return t;
}
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
......@@ -187,7 +217,8 @@ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
......@@ -195,7 +226,8 @@ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
......@@ -210,7 +242,8 @@ inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, _
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
......@@ -218,7 +251,8 @@ inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
......@@ -233,7 +267,8 @@ inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, _
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
fal = __low2float(a);
......
......@@ -16,22 +16,24 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_bf16_fallbacks.cuh"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include <cuda.h>
#include <cuda_fp16.h>
namespace fastertransformer {
template<typename T>
inline __device__ T ldg(const T* val) {
inline __device__ T ldg(const T* val)
{
return __ldg(val);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) {
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
......@@ -40,7 +42,8 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) {
}
template<>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) {
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
......@@ -51,258 +54,409 @@ inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) {
// Get type2 from type or vice versa (applied to half and bfloat16)
template<typename T>
struct TypeConverter {using Type = half2;}; // keep for generality
struct TypeConverter {
using Type = half2;
}; // keep for generality
template<>
struct TypeConverter<half2> {using Type = half;};
struct TypeConverter<half2> {
using Type = half;
};
template<>
struct TypeConverter<half> {using Type = half2;};
struct TypeConverter<half> {
using Type = half2;
};
#if ENABLE_BF16
template<>
struct TypeConverter<__nv_bfloat162> {using Type = __nv_bfloat16;};
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template<>
struct TypeConverter<__nv_bfloat16> {using Type = __nv_bfloat162;};
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#endif // ENABLE_BF16
// Defined math operations (bfloat16 fallback to fp32 when it is not supported)
template<typename T>
inline __device__ T hadd2(T a, T b) {
inline __device__ T hadd2(T a, T b)
{
return __hadd2(a, b);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b) {
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
#endif // ENABLE_BF16
template<typename T>
inline __device__ T add(T a, T b) {
inline __device__ T add(T a, T b)
{
return a + b;
}
template<>
inline __device__ half2 add(half2 a, half2 b) {
inline __device__ half2 add(half2 a, half2 b)
{
return __hadd2(a, b);
}
template<>
inline __device__ half add(half a, half b) {
inline __device__ half add(half a, half b)
{
return __hadd(a, b);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
template<>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
return bf16hadd(a, b);
}
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b) {
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)
{
return bf16hadd(a, __float2bfloat16(b));
}
#endif // ENABLE_BF16
// applies to all 4 values addition
template<typename T>
inline __device__ T add(T a, T b, T c) {
inline __device__ T add(T a, T b, T c)
{
return a + b + c;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hadd(a, b, c);
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hadd2(a, b, c);
}
#endif // ENABLE_BF16
// applies to all 4 values addition
template<typename T>
inline __device__ T add(T a, T b, T c, T d) {
inline __device__ T add(T a, T b, T c, T d)
{
return (T)((float)a + (float)b + (float)c + (float)d);
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
return bf16hadd(a, b, c, d);
}
#endif // ENABLE_BF16
template<typename T>
inline __device__ T hsub2(T a, T b) {
inline __device__ T hsub2(T a, T b)
{
return __hsub2(a, b);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b) {
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hsub2(a, b);
}
#endif // ENABLE_BF16
template<typename T>
inline __device__ T hmul2(T a, T b) {
inline __device__ T hmul2(T a, T b)
{
return __hmul2(a, b);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b) {
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hmul2(a, b);
}
#endif // ENABLE_BF16
template<typename T>
inline __device__ T hmul2(T a, T b, T c) {
inline __device__ T hmul2(T a, T b, T c)
{
return a * b * c;
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template<typename T>
inline __device__ T mul(T a, T b, T c) {
inline __device__ T mul(T a, T b, T c)
{
return a * b * c;
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hmul(a, b, c);
}
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template<typename T>
inline __device__ T fma(T a, T b, T c, T d) {
inline __device__ T fma(T a, T b, T c, T d)
{
return a * b * c + d;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
return bf16hfma2(a, b, c, d);
}
#endif // ENABLE_BF16
template<typename T>
inline __device__ T fma(T a, T b, T c) {
inline __device__ T fma(T a, T b, T c)
{
return a * b + c;
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(a, b, c);
}
template<>
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hfma(a, b, c);
}
#endif // ENABLE_BF16
template<typename T>
inline __device__ T hexp2(T a) {
inline __device__ T hexp2(T a)
{
return h2exp(a);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a) {
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)
{
return bf16exp2(a);
}
#endif // ENABLE_BF16
template<typename T_OUT, typename T_IN> __device__ inline T_OUT cuda_cast(T_IN val) { return val; }
template<typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val)
{
return val;
}
template<> __device__ inline float2 cuda_cast<float2, int2>(int2 val) { return make_float2(val.x, val.y); }
template<> __device__ inline float2 cuda_cast<float2, float>(float val) { return make_float2(val, val); }
template<> __device__ inline float2 cuda_cast<float2, half2>(half2 val) { return __half22float2(val); }
template<> __device__ inline half2 cuda_cast<half2, float2>(float2 val) { return __float22half2_rn(val); }
template<> __device__ inline half2 cuda_cast<half2, float>(float val) { return __float2half2_rn(val); }
template<> __device__ inline half2 cuda_cast<half2, half>(half val) { return __half2half2(val); }
template<>
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
{
return make_float2(val.x, val.y);
}
template<>
__device__ inline float2 cuda_cast<float2, float>(float val)
{
return make_float2(val, val);
}
template<>
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
{
return __half22float2(val);
}
template<>
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
{
return __float22half2_rn(val);
}
template<>
__device__ inline half2 cuda_cast<half2, float>(float val)
{
return __float2half2_rn(val);
}
template<>
__device__ inline half2 cuda_cast<half2, half>(half val)
{
return __half2half2(val);
}
template<> __device__ inline int8_t cuda_cast<int8_t, half>(half val) {
union { int8_t int8[2]; int16_t int16; };
union { half fp16; int16_t int16_in; };
template<>
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
union {
int8_t int8[2];
int16_t int16;
};
union {
half fp16;
int16_t int16_in;
};
fp16 = val;
asm volatile ("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
}
template<> __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) {
union { int8_t int8[2]; int16_t int16; };
template<>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
{
union {
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template<> __device__ inline int8_t cuda_cast<int8_t, float>(float val) {
union { int8_t int8[2]; int16_t int16; };
asm volatile ("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
template<>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
union {
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
template<> __device__ inline int16_t cuda_cast<int16_t, float2>(float2 val) {
union { int8_t int8[2]; int16_t int16; };
template<>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
{
union {
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template<> __device__ inline half2 cuda_cast<half2, int16_t>(int16_t val) {
union { int8_t int8[2]; int16_t int16; };
template<>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
{
union {
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_half2(int8[0], int8[1]);
}
template<> __device__ inline float2 cuda_cast<float2, int16_t>(int16_t val) {
union { int8_t int8[2]; int16_t int16; };
template<>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
{
union {
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_float2(int8[0], int8[1]);
}
#ifdef ENABLE_BF16
template<> __device__ inline __nv_bfloat16 cuda_cast(int32_t val) { return static_cast<float>(val); }
template<> __device__ inline __nv_bfloat16 cuda_cast(int8_t val) { return static_cast<float>(val); }
template<> __device__ inline int8_t cuda_cast(__nv_bfloat16 val) { return static_cast<float>(val); }
template<>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) { return __bfloat162float(val); }
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
return static_cast<float>(val);
}
template<>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
return static_cast<float>(val);
}
template<>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
return static_cast<float>(val);
}
template<> __device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val) { return bf1622float2(val); }
template<>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162float(val);
}
template<> __device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val) { return __float2half(__bfloat162float(val)); }
template<>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622float2(val);
}
template<> __device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val) { return bf1622int16(val); }
template<>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
{
return __float2half(__bfloat162float(val));
}
template<> __device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) { return __float2bfloat16(val); }
template<> __device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) { return __float2bfloat16(__half2float(val)); }
template<>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622int16(val);
}
template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
return __float2bfloat16(val);
}
template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
return __float2bfloat16(__half2float(val));
}
template<> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) { return bf162bf162(val); }
template<> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) { return __float2bfloat162_rn(val); }
template<> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) { return float22bf162(val); }
template<> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) {
union { int8_t int8[2]; int16_t int16; };
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
return bf162bf162(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
return __float2bfloat162_rn(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
return float22bf162(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
union {
int8_t int8[2];
int16_t int16;
};
int16 = val;
__nv_bfloat162 res;
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
......@@ -310,62 +464,138 @@ template<> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(i
return res;
}
template<> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) { return float22bf162(__half22float2(val)); }
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
return float22bf162(__half22float2(val));
}
#endif // ENABLE BF16
template<typename T> __device__ inline T cuda_abs(T val);
template<> __device__ inline float cuda_abs(float val) { return fabs(val); }
template<> __device__ inline half cuda_abs(half val) { return __habs(val); }
template<> __device__ inline half2 cuda_abs(half2 val) { return __habs2(val); }
template<typename T>
__device__ inline T cuda_abs(T val);
template<>
__device__ inline float cuda_abs(float val)
{
return fabs(val);
}
template<>
__device__ inline half cuda_abs(half val)
{
return __habs(val);
}
template<>
__device__ inline half2 cuda_abs(half2 val)
{
return __habs2(val);
}
#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template<> __device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) { return __habs(val); }
template<> __device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) { return __habs2(val); }
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return __habs(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return __habs2(val);
}
#else
template<> __device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) { return fabs(val); }
template<> __device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) { return make_bfloat162(fabs(val.x), fabs(val.y)); }
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return fabs(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return make_bfloat162(fabs(val.x), fabs(val.y));
}
#endif
#endif // ENABLE_FP16
// Unary maximum: compute the max of a vector type
template<typename To, typename Ti> __device__ inline To cuda_max(Ti val)
template<typename To, typename Ti>
__device__ inline To cuda_max(Ti val)
{
return cuda_cast<To>(val);
};
template<> __device__ inline half cuda_max(half2 val) { return (val.x > val.y) ? val.x : val.y; }
template<>
__device__ inline half cuda_max(half2 val)
{
return (val.x > val.y) ? val.x : val.y;
}
#ifdef ENABLE_BF16
template<> __device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) { return (val.x > val.y) ? val.x : val.y; }
template<>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
return (val.x > val.y) ? val.x : val.y;
}
#endif
// Binary maximum: compute the max of two scalar types
template<typename T> __device__ inline T cuda_max(T val1, T val2) { return (val1 > val2) ? val1 : val2; }
template<typename T>
__device__ inline T cuda_max(T val1, T val2)
{
return (val1 > val2) ? val1 : val2;
}
#ifdef ENABLE_FP8
template<> __device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val) { return bf1622float2(fp8x2_e4m3_to_bfloat2(&val)); }
template<> __device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val) { return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val))); }
template<>
__device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));
}
template<>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)
{
return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));
}
template<> __device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val) { return __nv_fp8_e4m3(val); }
template<> __device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val) { return __nv_fp8_e4m3(val); }
template<> __device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val) { return __nv_fp8_e4m3(val); }
template<> __device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val) { return (float)val; }
template<> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val) { return fp8x2_e4m3_to_bfloat2(&val); }
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)
{
return __nv_fp8_e4m3(val);
}
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)
{
return __nv_fp8_e4m3(val);
}
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)
{
return __nv_fp8_e4m3(val);
}
template<>
__device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
return (float)val;
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return fp8x2_e4m3_to_bfloat2(&val);
}
template<> __device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
template<>
__device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
// no impl
return 0;
}
template<> __device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
{
return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast<float>(val)));
}
#endif // ENABLE_FP8
}
} // namespace fastertransformer
......@@ -462,29 +462,29 @@ void generate_encoder_gemm_config(
T* d_C = d_B + k * n * batchCount[i];
T* dA_compressed;
{
cusparseLtMatDescriptor_t matA;
cusparseLtMatDescriptor_t mat_A;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
&handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(
cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
size_t compressed_size;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size))
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream))
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
}
float exec_time = 99999.0f;
int fast_algo = 0;
for (int alg = 0; alg < 4; ++alg) {
cudaDeviceSynchronize();
cusparseLtMatDescriptor_t matA, matB, matC;
cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
void* d_workspace = nullptr;
int num_streams = 1;
cudaStream_t streams[1] = {stream};
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_16F, order))
&handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
gettimeofday(&start, NULL);
for (int ite = 0; ite < ites; ++ite) {
// initializing MatDesc takes a lot of time
......@@ -494,7 +494,7 @@ void generate_encoder_gemm_config(
cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan;
CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
&handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type))
&handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
CHECK_CUSPARSE(
cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
......@@ -1239,15 +1239,15 @@ int generate_encoder_igemm_config(
int8_t* d_C = d_B + k * n;
int8_t* dA_compressed;
{
cusparseLtMatDescriptor_t matA;
cusparseLtMatDescriptor_t mat_A;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
&handle, &mat_A, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(
cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
size_t compressed_size;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size))
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream))
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
}
cudaDeviceSynchronize();
cudaError_t result = cudaGetLastError();
......@@ -1259,14 +1259,14 @@ int generate_encoder_igemm_config(
int fast_algo = 0;
for (int alg = 0; alg < 4; ++alg) {
cudaDeviceSynchronize();
cusparseLtMatDescriptor_t matA, matB, matC;
cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
void* d_workspace = nullptr;
int num_streams = 1;
cudaStream_t streams[1] = {stream};
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_8I, col_order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_8I, col_order))
&handle, &mat_A, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_8I, col_order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_8I, col_order))
gettimeofday(&start, NULL);
for (int ite = 0; ite < ites; ++ite) {
// initializing MatDesc takes a lot of time
......@@ -1276,7 +1276,7 @@ int generate_encoder_igemm_config(
cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan;
CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
&handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type))
&handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
CHECK_CUSPARSE(
cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
......@@ -617,15 +617,15 @@ void generate_gpt_gemm_config(int batch_size,
T* d_C = d_B + k * n * batchCount[i];
T* dA_compressed;
{
cusparseLtMatDescriptor_t matA;
cusparseLtMatDescriptor_t mat_A;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
&handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(
cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
size_t compressed_size;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size))
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream))
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
}
float exec_time = 99999.0f;
......@@ -633,14 +633,15 @@ void generate_gpt_gemm_config(int batch_size,
if (isSparseGemmAvailable(m, n, k)) {
for (int alg = 0; alg < 4; ++alg) {
cudaDeviceSynchronize();
cusparseLtMatDescriptor_t matA, matB, matC;
cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
void* d_workspace = nullptr;
int num_streams = 1;
cudaStream_t streams[1] = {stream};
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_16F, order))
&handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
CHECK_CUSPARSE(
cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
cudaDeviceSynchronize();
gettimeofday(&start, NULL);
for (int ite = 0; ite < ites; ++ite) {
......@@ -651,7 +652,7 @@ void generate_gpt_gemm_config(int batch_size,
cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan;
CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
&handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type))
&handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
CHECK_CUSPARSE(
cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
......@@ -616,15 +616,15 @@ void generate_t5_gemm_config(int batch_size,
T* d_C = d_B + k * n * batchCount[i];
T* dA_compressed;
{
cusparseLtMatDescriptor_t matA;
cusparseLtMatDescriptor_t mat_A;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
&handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(
cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
size_t compressed_size;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size))
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream))
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
}
float exec_time = 99999.0f;
......@@ -632,14 +632,15 @@ void generate_t5_gemm_config(int batch_size,
if (isSparseGemmAvailable(m, n, k)) {
for (int alg = 0; alg < 4; ++alg) {
cudaDeviceSynchronize();
cusparseLtMatDescriptor_t matA, matB, matC;
cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
void* d_workspace = nullptr;
int num_streams = 1;
cudaStream_t streams[1] = {stream};
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_16F, order))
&handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
CHECK_CUSPARSE(
cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
cudaDeviceSynchronize();
gettimeofday(&start, NULL);
for (int ite = 0; ite < ites; ++ite) {
......@@ -650,7 +651,7 @@ void generate_t5_gemm_config(int batch_size,
cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan;
CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
&handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type))
&handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
CHECK_CUSPARSE(
cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
......@@ -27,8 +27,7 @@ namespace fastertransformer {
class Logger {
public:
enum Level
{
enum Level {
TRACE = 0,
DEBUG = 10,
INFO = 20,
......
import torch
# flake8: noqa
import unittest
import torch
def random_tensor(shape, dtype, device, mean=0, std=1):
return torch.empty(shape, dtype=dtype, device=device).normal_(mean, std)
class TestGemmDequantize(unittest.TestCase):
def setUp(self) -> None:
torch.classes.load_library("lib/libth_transformer.so")
torch.classes.load_library("lib/libgemm_dq_unit_ops.so")
torch.classes.load_library('lib/libth_transformer.so')
torch.classes.load_library('lib/libgemm_dq_unit_ops.so')
self.unpack_packed_int4s = torch.ops.fastertransformer.unpack_int4_packed_tensor_to_int8
self.pack_int4s = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
self.fused_gemm_dq = torch.ops.gemm_dq_unit_ops.fused_gemm_dq
......@@ -26,17 +31,26 @@ class TestGemmDequantize(unittest.TestCase):
upper_bound = 127 if quant_type == torch.int8 else 7
m, n, k = 64, 128, 64
weights = torch.randint(lower_bound, upper_bound, [k, n], dtype=torch.int8, device="cpu")
packed_weight = self.pack_int4s(weights) if quant_type == torch.quint4x2 else weights
cuda_weights = self.preprocess_weights_for_mixed_gemm(packed_weight, quant_type).to("cuda")
weights = weights.to("cuda")
act = torch.eye(m, dtype=weight_type, device="cuda")
weights = torch.randint(lower_bound,
upper_bound, [k, n],
dtype=torch.int8,
device='cpu')
packed_weight = self.pack_int4s(
weights) if quant_type == torch.quint4x2 else weights
cuda_weights = self.preprocess_weights_for_mixed_gemm(
packed_weight, quant_type).to('cuda')
weights = weights.to('cuda')
act = torch.eye(m, dtype=weight_type, device='cuda')
scales = torch.ones([n], dtype=weight_type, device='cuda')
actual = self.fused_gemm_dq(act, cuda_weights, scales)
torch.testing.assert_close(actual, weights, atol=0, rtol=0, check_dtype=False)
torch.testing.assert_close(actual,
weights,
atol=0,
rtol=0,
check_dtype=False)
def test_fp16_int8_dequantize(self):
self.dequantize_test_helper(torch.float16, torch.int8)
......@@ -51,144 +65,207 @@ class TestGemmDequantize(unittest.TestCase):
self.dequantize_test_helper(torch.bfloat16, torch.quint4x2)
def apply_act(self, inp, act_str):
if act_str == "identity":
if act_str == 'identity':
return inp
elif act_str == "silu":
elif act_str == 'silu':
return torch.nn.SiLU()(inp)
elif act_str == "relu":
elif act_str == 'relu':
return torch.nn.ReLU()(inp)
elif act_str == "gelu":
return torch.nn.GELU(approximate="tanh")(inp)
elif act_str == 'gelu':
return torch.nn.GELU(approximate='tanh')(inp)
else:
assert False, "Unsupported activation"
def gemm_dequant_test_helper(self, compute_type, weight_dtype, gemm_ms, gemm_ns, gemm_ks, rtol, atol, act_str="only_gemm", benchmark=False):
assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, "Weight must be quantized"
assert False, 'Unsupported activation'
def gemm_dequant_test_helper(self,
compute_type,
weight_dtype,
gemm_ms,
gemm_ns,
gemm_ks,
rtol,
atol,
act_str='only_gemm',
benchmark=False):
assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, 'Weight must be quantized'
for gemm_k in gemm_ks:
for gemm_n in gemm_ns:
torch_weights_cpu = random_tensor((gemm_k, gemm_n), dtype=compute_type, device="cpu", mean=0, std=0.002)
ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(torch_weights_cpu, weight_dtype)
ref_torch_weights = self.unpack_packed_int4s(ref_torch_weights) if weight_dtype == torch.quint4x2 else ref_torch_weights
ref_torch_weights = ref_torch_weights.to("cuda")
processed_torch_weights = processed_torch_weights.to("cuda")
torch_weight_scales = torch_weight_scales.to("cuda")
torch_biases = random_tensor((gemm_n), dtype=compute_type, device="cuda", mean=0, std=0.1)
torch_weights_cpu = random_tensor((gemm_k, gemm_n),
dtype=compute_type,
device='cpu',
mean=0,
std=0.002)
ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(
torch_weights_cpu, weight_dtype)
ref_torch_weights = self.unpack_packed_int4s(
ref_torch_weights
) if weight_dtype == torch.quint4x2 else ref_torch_weights
ref_torch_weights = ref_torch_weights.to('cuda')
processed_torch_weights = processed_torch_weights.to('cuda')
torch_weight_scales = torch_weight_scales.to('cuda')
torch_biases = random_tensor((gemm_n),
dtype=compute_type,
device='cuda',
mean=0,
std=0.1)
for num_rows in gemm_ms:
torch_activations = torch.randn(size=(num_rows, gemm_k), dtype=compute_type, device="cuda")
torch_activations = torch.randn(size=(num_rows, gemm_k),
dtype=compute_type,
device='cuda')
scales_unsqueezed = torch_weight_scales.unsqueeze(0)
casted_weights = ref_torch_weights.to(torch_activations.dtype)
dequantized_weights = torch.multiply(casted_weights, scales_unsqueezed)
casted_weights = ref_torch_weights.to(
torch_activations.dtype)
dequantized_weights = torch.multiply(
casted_weights, scales_unsqueezed)
if benchmark:
assert act_str == "only_gemm", "Benchmarks against cublas must use just GEMM."
assert act_str == 'only_gemm', 'Benchmarks against cublas must use just GEMM.'
torch.cuda.profiler.start()
times, results = self.bench(torch_activations, processed_torch_weights, torch_weight_scales, dequantized_weights, 200)
times, results = self.bench(torch_activations,
processed_torch_weights,
torch_weight_scales,
dequantized_weights, 200)
torch.cuda.profiler.stop()
times = times[0]
cublas_time = times[0].item()
ft_time = times[1].item()
ft_speedup = cublas_time / ft_time
print("{},{},{},{},{},{}".format(num_rows, gemm_n, gemm_k, cublas_time, ft_time, ft_speedup))
print('{},{},{},{},{},{}'.format(
num_rows, gemm_n, gemm_k, cublas_time, ft_time,
ft_speedup))
reference_result = results[0]
ft_result = results[1]
else:
if act_str == "only_gemm":
reference_result = torch.matmul(torch_activations, dequantized_weights)
ft_result = self.fused_gemm_dq(torch_activations, processed_torch_weights, torch_weight_scales)
if act_str == 'only_gemm':
reference_result = torch.matmul(
torch_activations, dequantized_weights)
ft_result = self.fused_gemm_dq(
torch_activations, processed_torch_weights,
torch_weight_scales)
else:
reference_result = torch.matmul(torch_activations, dequantized_weights)
reference_result = torch.matmul(
torch_activations, dequantized_weights)
reference_result += torch_biases.unsqueeze(0)
reference_result = self.apply_act(reference_result, act_str)
ft_result = self.fused_gemm_dq_bias_act(torch_activations, processed_torch_weights, torch_weight_scales, torch_biases, act_str)
msg = "FC1 Failed on m={}, n={}, k={}".format(num_rows, gemm_n, gemm_k)
torch.testing.assert_close(ft_result, reference_result, rtol=rtol, atol=atol, msg=msg, check_dtype=False)
reference_result = self.apply_act(
reference_result, act_str)
ft_result = self.fused_gemm_dq_bias_act(
torch_activations, processed_torch_weights,
torch_weight_scales, torch_biases, act_str)
msg = 'FC1 Failed on m={}, n={}, k={}'.format(
num_rows, gemm_n, gemm_k)
torch.testing.assert_close(ft_result,
reference_result,
rtol=rtol,
atol=atol,
msg=msg,
check_dtype=False)
def test_fp16_int8_gemm(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8,
gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
gemm_ns = [1024, 2048, 4096],
gemm_ks = [4096, 8192, 16384],
rtol=0.001, atol=0.002)
self.gemm_dequant_test_helper(
torch.float16,
torch.int8,
gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
gemm_ns=[1024, 2048, 4096],
gemm_ks=[4096, 8192, 16384],
rtol=0.001,
atol=0.002)
def test_fp16_int4_gemm(self):
self.gemm_dequant_test_helper(torch.float16, torch.quint4x2,
gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
gemm_ns = [1024, 2048, 4096],
gemm_ks = [4096, 8192, 16384],
rtol=0.001, atol=0.002)
self.gemm_dequant_test_helper(
torch.float16,
torch.quint4x2,
gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
gemm_ns=[1024, 2048, 4096],
gemm_ks=[4096, 8192, 16384],
rtol=0.001,
atol=0.002)
def test_bf16_int8_gemm(self):
self.gemm_dequant_test_helper(torch.bfloat16, torch.int8,
gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
gemm_ns = [1024, 2048, 4096],
gemm_ks = [4096, 8192, 16384],
rtol=0.01, atol=0.01)
self.gemm_dequant_test_helper(
torch.bfloat16,
torch.int8,
gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
gemm_ns=[1024, 2048, 4096],
gemm_ks=[4096, 8192, 16384],
rtol=0.01,
atol=0.01)
def test_bf16_int4_gemm(self):
self.gemm_dequant_test_helper(torch.bfloat16, torch.quint4x2,
gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
gemm_ns = [1024, 2048, 4096],
gemm_ks = [4096, 8192, 16384],
rtol=0.01, atol=0.01)
self.gemm_dequant_test_helper(
torch.bfloat16,
torch.quint4x2,
gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
gemm_ns=[1024, 2048, 4096],
gemm_ks=[4096, 8192, 16384],
rtol=0.01,
atol=0.01)
def test_fp16_int8_gemm_bias(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8,
gemm_ms = [256],
gemm_ns = [1024],
gemm_ks = [8192],
rtol=0.001, atol=0.002,
act_str="identity")
self.gemm_dequant_test_helper(torch.float16,
torch.int8,
gemm_ms=[256],
gemm_ns=[1024],
gemm_ks=[8192],
rtol=0.001,
atol=0.002,
act_str='identity')
def test_fp16_int8_gemm_bias_relu(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8,
gemm_ms = [256],
gemm_ns = [1024],
gemm_ks = [8192],
rtol=0.001, atol=0.002,
act_str="relu")
self.gemm_dequant_test_helper(torch.float16,
torch.int8,
gemm_ms=[256],
gemm_ns=[1024],
gemm_ks=[8192],
rtol=0.001,
atol=0.002,
act_str='relu')
def test_fp16_int8_gemm_bias_gelu(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8,
gemm_ms = [256],
gemm_ns = [1024],
gemm_ks = [8192],
rtol=0.001, atol=0.002,
act_str="gelu")
self.gemm_dequant_test_helper(torch.float16,
torch.int8,
gemm_ms=[256],
gemm_ns=[1024],
gemm_ks=[8192],
rtol=0.001,
atol=0.002,
act_str='gelu')
def test_fp16_int8_gemm_bias_silu(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8,
gemm_ms = [256],
gemm_ns = [1024],
gemm_ks = [8192],
rtol=0.001, atol=0.002,
act_str="silu")
self.gemm_dequant_test_helper(torch.float16,
torch.int8,
gemm_ms=[256],
gemm_ns=[1024],
gemm_ks=[8192],
rtol=0.001,
atol=0.002,
act_str='silu')
def bench_helper(self, act_type, quant_type, rtol, atol):
# Warm, using bfloat here since it seems to reliably use cublas.
x = random_tensor([20480, 20480], torch.bfloat16, device="cuda")
x = random_tensor([20480, 20480], torch.bfloat16, device='cuda')
warm_iters = 30
for iter in range(warm_iters):
res = x @ x
m_shapes = torch.arange(0, 12)
m_shapes = 2 ** m_shapes
m_shapes = 2**m_shapes
self.gemm_dequant_test_helper(act_type, quant_type,
gemm_ms = [128],
gemm_ns = [1536],
gemm_ks = [12288],
rtol=rtol, atol=atol, benchmark=True)
self.gemm_dequant_test_helper(act_type,
quant_type,
gemm_ms=[128],
gemm_ns=[1536],
gemm_ks=[12288],
rtol=rtol,
atol=atol,
benchmark=True)
@unittest.skip("This is a benchmark so don't run by default")
def test_fp16_int8_cublas(self):
self.bench_helper(torch.float16, torch.int8, 1e-3, 0.002)
@unittest.skip("This is a benchmark so don't run by default")
def test_bf16_int8_cublas(self):
self.bench_helper(torch.bfloat16, torch.int8, 1e-2, 1e-2)
......@@ -197,10 +274,10 @@ class TestGemmDequantize(unittest.TestCase):
def test_fp16_int4_cublas(self):
self.bench_helper(torch.float16, torch.quint4x2, 1e-3, 0.002)
@unittest.skip("This is a benchmark so don't run by default")
def test_bf16_int4_cublas(self):
self.bench_helper(torch.bfloat16, torch.quint4x2, 1e-2, 1e-2)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment