Unverified Commit fe46dac2 authored by AllentDan, committed by GitHub

Add lint action (#32)

* temp

* fix lint

* csrc->src

* remove clang-format

* skip .rst

* skip doc

* clang-format

version

version

* mat_B
parent e8ab4ba3
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "stdio.h"
#include "stdlib.h"

// be consistent with FasterTransformer
int8_t float_to_int8_rn_host(float x)
{
    int8_t  res;
    int32_t tmp;
    if (x >= 0) {
        tmp = int(x + 0.5);
        tmp = tmp > 127 ? 127 : tmp;
        res = int8_t(tmp);
    }
    else {
        tmp = int(x - 0.5);
        tmp = tmp < -127 ? -127 : tmp;
        res = int8_t(tmp);
    }
    return res;
}
\ No newline at end of file
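A quick host-side check of the rounding and saturation behaviour above: round half away from zero, then clamp symmetrically to [-127, 127] as FasterTransformer does for int8 quantization. The test harness below is illustrative only (not part of the commit) and assumes the helper above is in scope.

#include <cstdio>

int main()
{
    // Expected outputs: 0, 1, -1, 127, 127, -127
    const float inputs[] = {0.4f, 0.5f, -0.5f, 126.7f, 300.f, -300.f};
    for (float x : inputs) {
        printf("%8.2f -> %4d\n", x, (int)float_to_int8_rn_host(x));
    }
    return 0;
}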
@@ -509,10 +509,10 @@ void cublasINT8MMWrapper::SpGemm(
    }
    else {
        // initializing MatDesc takes a lot of time
        cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
        sp_mat_A_desc_map_[mark] = mat_A;
        sp_mat_B_desc_map_[mark] = mat_B;
        sp_mat_C_desc_map_[mark] = mat_C;
        CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
                                                          &sp_mat_A_desc_map_[mark],
                                                          num_A_rows,
...
@@ -695,10 +695,10 @@ void cublasMMWrapper::SpGemm(cublasOperation_t transa,
    }
    else {
        // initializing MatDesc takes a lot of time
        cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
        sp_mat_A_desc_map_[mark] = mat_A;
        sp_mat_B_desc_map_[mark] = mat_B;
        sp_mat_C_desc_map_[mark] = mat_C;
        CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
                                                          &sp_mat_A_desc_map_[mark],
                                                          num_A_rows,
@@ -752,9 +752,9 @@ size_t cublasMMWrapper::getSparseMatrixSize(int m, int k)
    int num_A_cols = k;
    int lda        = num_A_rows;
    cusparseLtMatDescriptor_t mat_A;
    CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
                                                      &mat_A,
                                                      num_A_rows,
                                                      num_A_cols,
                                                      lda,
@@ -763,7 +763,7 @@ size_t cublasMMWrapper::getSparseMatrixSize(int m, int k)
                                                      order,
                                                      CUSPARSELT_SPARSITY_50_PERCENT));
    size_t compressed_size = 0;
    CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&cusparselt_handle_, &mat_A, &compressed_size));
    return compressed_size;
}
@@ -771,11 +771,11 @@ void cublasMMWrapper::compressMatrix(const void* input, void* output, const int
{
    cusparseOrder_t           order = CUSPARSE_ORDER_COL;
    cusparseOperation_t       opA   = CUSPARSE_OPERATION_NON_TRANSPOSE;
    cusparseLtMatDescriptor_t mat_A;
    unsigned                  alignment = 16;
    CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
        &cusparselt_handle_, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
    CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &mat_A, true, opA, input, output, stream_))
    sync_check_cuda_error();
}
...
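The comment "initializing MatDesc takes a lot of time" is why the wrappers above cache cuSPARSELt descriptors in per-shape maps instead of rebuilding them on every call. A minimal sketch of that caching pattern, assuming a std::unordered_map keyed by a shape string (mirroring the `mark` / `sp_mat_A_desc_map_` members above; function and variable names are illustrative, and real code would wrap the status in CHECK_CUSPARSE):

#include <cusparseLt.h>
#include <string>
#include <unordered_map>

// Cache one structured descriptor per (m, k) shape so the expensive
// cusparseLtStructuredDescriptorInit() runs only on the first call.
cusparseLtMatDescriptor_t& getCachedSparseDesc(
    cusparseLtHandle_t* handle,
    std::unordered_map<std::string, cusparseLtMatDescriptor_t>& cache,
    int m, int k, unsigned alignment)
{
    const std::string mark = std::to_string(m) + "_" + std::to_string(k);
    auto it = cache.find(mark);
    if (it == cache.end()) {
        // operator[] default-constructs the slot; initialize it in place.
        cusparseLtMatDescriptor_t& slot = cache[mark];
        cusparseLtStructuredDescriptorInit(handle,
                                           &slot,
                                           m,   // rows
                                           k,   // cols
                                           m,   // leading dimension (column-major)
                                           alignment,
                                           CUDA_R_16F,
                                           CUSPARSE_ORDER_COL,
                                           CUSPARSELT_SPARSITY_50_PERCENT);
        return slot;
    }
    return it->second;
}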
@@ -22,10 +22,11 @@
namespace fastertransformer {
#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = __low2float(val);
    f_val.y = __high2float(val);
    return f_val;
#else
@@ -33,26 +34,34 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
    return int16;
#else
    val = __hmin2(val, make_bfloat162(127., 127.));
    val = __hmax2(val, make_bfloat162(-128., -128.));
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
    return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __floats2bfloat162_rn(val.x, val.y);
#else
@@ -60,7 +69,8 @@ inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    __nv_bfloat162 val2;
    val2.x = val;
@@ -71,7 +81,8 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
@@ -84,15 +95,17 @@ inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bf
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
    return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
@@ -105,15 +118,17 @@ inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bf
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
    return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
@@ -126,15 +141,17 @@ inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bf
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
    return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh, fzl, fzh;
    fxl = __low2float(x);
@@ -149,19 +166,22 @@ inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bf
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
    return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    ;
    return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
    return h2exp(x);
@@ -169,17 +189,27 @@ inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    return bf16hmul2(x, y);
};
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    return bf16hadd2(x, y);
};
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
    __nv_bfloat162 t;
    t.x = x;
    t.y = y;
    return t;
}
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
@@ -187,7 +217,8 @@ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
@@ -195,7 +226,8 @@ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
@@ -210,7 +242,8 @@ inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, _
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
@@ -218,7 +251,8 @@ inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
@@ -233,7 +267,8 @@ inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, _
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
    fal = __low2float(a);
@@ -250,6 +285,6 @@ inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, _
#endif
}
#endif  // ENABLE_BF16
}  // namespace fastertransformer
\ No newline at end of file
@@ -18,4 +18,4 @@
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
\ No newline at end of file
@@ -121,4 +121,4 @@ template void
invokeComputeFP8QuantizeScale(float* quant_ptr, const float* weights, const int k, const int n, cudaStream_t stream);
#endif  // ENABLE_FP8
}  // namespace fastertransformer
\ No newline at end of file
@@ -16,22 +16,24 @@
#pragma once
#include "src/fastertransformer/utils/cuda_bf16_fallbacks.cuh"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include <cuda.h>
#include <cuda_fp16.h>
namespace fastertransformer {
template<typename T>
inline __device__ T ldg(const T* val)
{
    return __ldg(val);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return val[0];
#else
@@ -40,269 +42,421 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) {
}
template<>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return val[0];
#else
    return __ldg(val);
#endif
}
#endif  // ENABLE_BF16
// Get type2 from type or vice versa (applied to half and bfloat16)
template<typename T>
struct TypeConverter {
    using Type = half2;
};  // keep for generality
template<>
struct TypeConverter<half2> {
    using Type = half;
};
template<>
struct TypeConverter<half> {
    using Type = half2;
};
#if ENABLE_BF16
template<>
struct TypeConverter<__nv_bfloat162> {
    using Type = __nv_bfloat16;
};
template<>
struct TypeConverter<__nv_bfloat16> {
    using Type = __nv_bfloat162;
};
#endif  // ENABLE_BF16
// Defined math operations (bfloat16 fallback to fp32 when it is not supported)
template<typename T>
inline __device__ T hadd2(T a, T b)
{
    return __hadd2(a, b);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hadd2(a, b);
}
#endif  // ENABLE_BF16
template<typename T>
inline __device__ T add(T a, T b)
{
    return a + b;
}
template<>
inline __device__ half2 add(half2 a, half2 b)
{
    return __hadd2(a, b);
}
template<>
inline __device__ half add(half a, half b)
{
    return __hadd(a, b);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hadd2(a, b);
}
template<>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
    return bf16hadd(a, b);
}
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)
{
    return bf16hadd(a, __float2bfloat16(b));
}
#endif  // ENABLE_BF16
// applies to all 4 values addition
template<typename T>
inline __device__ T add(T a, T b, T c)
{
    return a + b + c;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
    return bf16hadd(a, b, c);
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hadd2(a, b, c);
}
#endif  // ENABLE_BF16
// applies to all 4 values addition
template<typename T>
inline __device__ T add(T a, T b, T c, T d)
{
    return (T)((float)a + (float)b + (float)c + (float)d);
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
    return bf16hadd(a, b, c, d);
}
#endif  // ENABLE_BF16
template<typename T>
inline __device__ T hsub2(T a, T b)
{
    return __hsub2(a, b);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hsub2(a, b);
}
#endif  // ENABLE_BF16
template<typename T>
inline __device__ T hmul2(T a, T b)
{
    return __hmul2(a, b);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hmul2(a, b);
}
#endif  // ENABLE_BF16
template<typename T>
inline __device__ T hmul2(T a, T b, T c)
{
    return a * b * c;
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hmul2(a, b, c);
}
#endif  // ENABLE_BF16
template<typename T>
inline __device__ T mul(T a, T b, T c)
{
    return a * b * c;
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
    return bf16hmul(a, b, c);
}
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hmul2(a, b, c);
}
#endif  // ENABLE_BF16
template<typename T>
inline __device__ T fma(T a, T b, T c, T d)
{
    return a * b * c + d;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
    return bf16hfma2(a, b, c, d);
}
#endif  // ENABLE_BF16
template<typename T>
inline __device__ T fma(T a, T b, T c)
{
    return a * b + c;
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hfma2(a, b, c);
}
template<>
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
    return bf16hfma(a, b, c);
}
#endif  // ENABLE_BF16
template<typename T>
inline __device__ T hexp2(T a)
{
    return h2exp(a);
}
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)
{
    return bf16exp2(a);
}
#endif  // ENABLE_BF16
template<typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val)
{
    return val;
}
template<>
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
{
    return make_float2(val.x, val.y);
}
template<>
__device__ inline float2 cuda_cast<float2, float>(float val)
{
    return make_float2(val, val);
}
template<>
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
{
    return __half22float2(val);
}
template<>
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
{
    return __float22half2_rn(val);
}
template<>
__device__ inline half2 cuda_cast<half2, float>(float val)
{
    return __float2half2_rn(val);
}
template<>
__device__ inline half2 cuda_cast<half2, half>(half val)
{
    return __half2half2(val);
}
template<>
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    union {
        half    fp16;
        int16_t int16_in;
    };
    fp16 = val;
    asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
    return int8[0];
}
template<>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = cuda_cast<int8_t>(val.x);
    int8[1] = cuda_cast<int8_t>(val.y);
    return int16;
}
template<>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
    return int8[0];
}
template<>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = cuda_cast<int8_t>(val.x);
    int8[1] = cuda_cast<int8_t>(val.y);
    return int16;
}
template<>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int16 = val;
    return make_half2(int8[0], int8[1]);
}
template<>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int16 = val;
    return make_float2(int8[0], int8[1]);
}
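cuda_cast gives kernels a single spelling for all of the conversions above (float, half, int8 and their packed forms). A brief hedged sketch of how it is typically used when dequantizing packed int8 values; the kernel name and layout are illustrative:

// Convert a packed pair of int8 values (stored as int16_t) to float2,
// scale, and write back as half2 -- all through the cuda_cast overloads above.
__global__ void dequantize_pairs(const int16_t* in, half2* out, float scale, int n2)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n2) {
        float2 f = cuda_cast<float2>(in[i]);  // int8x2 -> float2
        f.x *= scale;
        f.y *= scale;
        out[i] = cuda_cast<half2>(f);         // float2 -> half2
    }
}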
#ifdef ENABLE_BF16
template<>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
    return static_cast<float>(val);
}
template<>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
    return static_cast<float>(val);
}
template<>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
    return static_cast<float>(val);
}
template<>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
{
    return __bfloat162float(val);
}
template<>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
    return bf1622float2(val);
}
template<>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
{
    return __float2half(__bfloat162float(val));
}
template<>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
{
    return bf1622int16(val);
}
template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
    return __float2bfloat16(val);
}
template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
    return __float2bfloat16(__half2float(val));
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
    return bf162bf162(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
    return __float2bfloat162_rn(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
    return float22bf162(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int16 = val;
    __nv_bfloat162 res;
    res.x = cuda_cast<__nv_bfloat16>(int8[0]);
@@ -310,62 +464,138 @@ template<> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(i
    return res;
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
    return float22bf162(__half22float2(val));
}
#endif  // ENABLE BF16
template<typename T>
__device__ inline T cuda_abs(T val);
template<>
__device__ inline float cuda_abs(float val)
{
    return fabs(val);
}
template<>
__device__ inline half cuda_abs(half val)
{
    return __habs(val);
}
template<>
__device__ inline half2 cuda_abs(half2 val)
{
    return __habs2(val);
}
#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
    return __habs(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
    return __habs2(val);
}
#else
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
    return fabs(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
    return make_bfloat162(fabs(val.x), fabs(val.y));
}
#endif
#endif  // ENABLE_FP16
// Unary maximum: compute the max of a vector type
template<typename To, typename Ti>
__device__ inline To cuda_max(Ti val)
{
    return cuda_cast<To>(val);
};
template<>
__device__ inline half cuda_max(half2 val)
{
    return (val.x > val.y) ? val.x : val.y;
}
#ifdef ENABLE_BF16
template<>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
    return (val.x > val.y) ? val.x : val.y;
}
#endif
// Binary maximum: compute the max of two scalar types
template<typename T>
__device__ inline T cuda_max(T val1, T val2)
{
    return (val1 > val2) ? val1 : val2;
}
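The unary form collapses a packed pair to its larger lane, and the binary form compares two scalars; combined with cuda_abs they cover the per-thread part of an absolute-max reduction over vectorized data. A small hedged sketch (the helper name and the caller's block-level reduction are illustrative and not part of this header):

// Per-thread running maximum of |x| over half2 data.
__device__ half thread_abs_max(const half2* vec, int n2)
{
    half m = __float2half(0.f);
    for (int i = threadIdx.x; i < n2; i += blockDim.x) {
        m = cuda_max(m, cuda_max<half>(cuda_abs(vec[i])));  // fold both lanes in
    }
    return m;  // caller still reduces across the block (not shown)
}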
#ifdef ENABLE_FP8
template<>
__device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
    return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));
}
template<>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)
{
    return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));
}
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)
{
    return __nv_fp8_e4m3(val);
}
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)
{
    return __nv_fp8_e4m3(val);
}
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)
{
    return __nv_fp8_e4m3(val);
}
template<>
__device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
    return (float)val;
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
    return fp8x2_e4m3_to_bfloat2(&val);
}
template<>
__device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
    // no impl
    return 0;
}
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
{
    return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast<float>(val)));
}
#endif  // ENABLE_FP8
}  // namespace fastertransformer
@@ -84,4 +84,4 @@ struct CustomARCommTypeConverter<__nv_bfloat16> {
};
#endif
}  // namespace fastertransformer
\ No newline at end of file
@@ -462,29 +462,29 @@ void generate_encoder_gemm_config(
        T* d_C = d_B + k * n * batchCount[i];
        T* dA_compressed;
        {
            cusparseLtMatDescriptor_t mat_A;
            CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
                &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
            CHECK_CUSPARSE(
                cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
            size_t compressed_size;
            CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
            check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
            CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
        }
        float exec_time = 99999.0f;
        int   fast_algo = 0;
        for (int alg = 0; alg < 4; ++alg) {
            cudaDeviceSynchronize();
            cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
            void*        d_workspace = nullptr;
            int          num_streams = 1;
            cudaStream_t streams[1]  = {stream};
            CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
                &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
            CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
            CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
            gettimeofday(&start, NULL);
            for (int ite = 0; ite < ites; ++ite) {
                // initializing MatDesc takes a lot of time
@@ -494,7 +494,7 @@ void generate_encoder_gemm_config(
                cusparseLtMatmulAlgSelection_t alg_sel;
                cusparseLtMatmulPlan_t         plan;
                CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
                    &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
                CHECK_CUSPARSE(
                    cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
                CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
...
/* /*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "encoder_igemm_func.h" #include "encoder_igemm_func.h"
#ifndef CUDART_VERSION #ifndef CUDART_VERSION
#error CUDART_VERSION Undefined! #error CUDART_VERSION Undefined!
#endif #endif
namespace fastertransformer { namespace fastertransformer {
int batch_size_; int batch_size_;
int seq_len_; int seq_len_;
int head_num_; int head_num_;
int size_per_head_; int size_per_head_;
static const char* showStatus(cublasStatus_t error) static const char* showStatus(cublasStatus_t error)
{ {
switch (error) { switch (error) {
case CUBLAS_STATUS_SUCCESS: case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS"; return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED: case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED"; return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED: case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED"; return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE: case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE"; return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH: case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH"; return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR: case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR"; return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED: case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED"; return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR: case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR"; return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED: case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED"; return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR: case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR"; return "CUBLAS_STATUS_LICENSE_ERROR";
} }
return "<unknown>"; return "<unknown>";
} }
// Utility function to print customMatmulPerf_t structure // Utility function to print customMatmulPerf_t structure
int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE* fout, int hasPrint) int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE* fout, int hasPrint)
{ {
int algoId, tile, swizzle, customOption, numSplitsK, reductionScheme, stages; int algoId, tile, swizzle, customOption, numSplitsK, reductionScheme, stages;
const cublasLtMatmulAlgo_t* matmulAlgo = &perf.algo; const cublasLtMatmulAlgo_t* matmulAlgo = &perf.algo;
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_ID, &algoId, sizeof(algoId), NULL); cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_ID, &algoId, sizeof(algoId), NULL);
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tile, sizeof(tile), NULL); cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tile, sizeof(tile), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &numSplitsK, sizeof(numSplitsK), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &numSplitsK, sizeof(numSplitsK), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &reductionScheme, sizeof(reductionScheme), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &reductionScheme, sizeof(reductionScheme), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL);
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL); cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else #else
stages = 0; stages = 0;
#endif #endif
printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d " printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
"time %f workspace=%d mathMode=%d waves=%f\n", "time %f workspace=%d mathMode=%d waves=%f\n",
algoId, algoId,
tile, tile,
matmulTileName[tile], matmulTileName[tile],
numSplitsK, numSplitsK,
reductionScheme, reductionScheme,
swizzle, swizzle,
customOption, customOption,
stages, stages,
perf.status, perf.status,
perf.time, perf.time,
(int)perf.workspaceSize, (int)perf.workspaceSize,
(int)perf.mathMode, (int)perf.mathMode,
perf.wavesCount); perf.wavesCount);
// chose the fastest algo that does not need workspace // chose the fastest algo that does not need workspace
if ((int)perf.workspaceSize == 0 && hasPrint == 0) { if ((int)perf.workspaceSize == 0 && hasPrint == 0) {
fprintf(fout, fprintf(fout,
"%d %d %d %d %d ### 1 %d %d %d %d %d %d %d %d %d %d %d %f\n", "%d %d %d %d %d ### 1 %d %d %d %d %d %d %d %d %d %d %d %f\n",
batch_size_, batch_size_,
seq_len_, seq_len_,
head_num_, head_num_,
size_per_head_, size_per_head_,
INT8_DATATYPE, INT8_DATATYPE,
m, m,
n, n,
k, k,
algoId, algoId,
customOption, customOption,
tile, tile,
numSplitsK, numSplitsK,
swizzle, swizzle,
reductionScheme, reductionScheme,
(int)perf.workspaceSize, (int)perf.workspaceSize,
stages, stages,
perf.time); perf.time);
return 1; return 1;
} }
else { else {
return hasPrint; return hasPrint;
} }
} }
int printBatchPerfStructure( int printBatchPerfStructure(
int batchCount, int m, int n, int k, const customMatmulPerf_t& perf, FILE* fout, int hasPrint) int batchCount, int m, int n, int k, const customMatmulPerf_t& perf, FILE* fout, int hasPrint)
{ {
int algoId, tile, swizzle, customOption, numSplitsK, reductionScheme, stages; int algoId, tile, swizzle, customOption, numSplitsK, reductionScheme, stages;
const cublasLtMatmulAlgo_t* matmulAlgo = &perf.algo; const cublasLtMatmulAlgo_t* matmulAlgo = &perf.algo;
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_ID, &algoId, sizeof(algoId), NULL); cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_ID, &algoId, sizeof(algoId), NULL);
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tile, sizeof(tile), NULL); cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tile, sizeof(tile), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &numSplitsK, sizeof(numSplitsK), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &numSplitsK, sizeof(numSplitsK), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &reductionScheme, sizeof(reductionScheme), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &reductionScheme, sizeof(reductionScheme), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL);
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL); cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else #else
stages = 0; stages = 0;
#endif #endif
printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d " printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
"time %f workspace=%d mathMode=%d waves=%f\n", "time %f workspace=%d mathMode=%d waves=%f\n",
algoId, algoId,
tile, tile,
matmulTileName[tile], matmulTileName[tile],
numSplitsK, numSplitsK,
reductionScheme, reductionScheme,
swizzle, swizzle,
customOption, customOption,
stages, stages,
perf.status, perf.status,
perf.time, perf.time,
(int)perf.workspaceSize, (int)perf.workspaceSize,
(int)perf.mathMode, (int)perf.mathMode,
perf.wavesCount); perf.wavesCount);
// chose the fastest algo that does not need workspace // chose the fastest algo that does not need workspace
if ((int)perf.workspaceSize == 0 && hasPrint == 0) { if ((int)perf.workspaceSize == 0 && hasPrint == 0) {
fprintf(fout, fprintf(fout,
"%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d %f\n", "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d %f\n",
batch_size_, batch_size_,
seq_len_, seq_len_,
head_num_, head_num_,
size_per_head_, size_per_head_,
INT8_DATATYPE, INT8_DATATYPE,
batchCount, batchCount,
m, m,
n, n,
k, k,
algoId, algoId,
customOption, customOption,
tile, tile,
numSplitsK, numSplitsK,
swizzle, swizzle,
reductionScheme, reductionScheme,
(int)perf.workspaceSize, (int)perf.workspaceSize,
stages, stages,
perf.time); perf.time);
return 1; return 1;
} }
else { else {
return hasPrint; return hasPrint;
} }
} }
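// Editorial note: the record written by printBatchPerfStructure() above follows the header emitted
// later in generate_encoder_igemm_config():
//   batch_size seq_len head_num size_per_head dataType ### batchCount m n k algoId customOption
//   tile splitK_val swizzle reductionScheme workspaceSize stages exec_time
// printPerfStructure() writes the same record with batchCount hard-coded to 1 (the literal "1" in its
// format string).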
static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMatmulPerf_t& perf_b) static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMatmulPerf_t& perf_b)
{ {
return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time)); return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time));
} }
static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle,  // to get the capabilities (requires a GPU) static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle,  // to get the capabilities (requires a GPU)
cublasLtMatmulDesc_t operationDesc, cublasLtMatmulDesc_t operationDesc,
const void* alpha, /* host or device pointer */ const void* alpha, /* host or device pointer */
const void* A, const void* A,
cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Adesc,
const void* B, const void* B,
cublasLtMatrixLayout_t Bdesc, cublasLtMatrixLayout_t Bdesc,
const void* beta, /* host or device pointer */ const void* beta, /* host or device pointer */
const void* C, const void* C,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Cdesc,
void* D, void* D,
cublasLtMatrixLayout_t Ddesc, cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t& algo, const cublasLtMatmulAlgo_t& algo,
int kernelRepeats, int kernelRepeats,
void* workSpace, void* workSpace,
size_t workSpaceSizeInBytes, size_t workSpaceSizeInBytes,
customMatmulPerf_t& perfResults, customMatmulPerf_t& perfResults,
cudaStream_t stream) cudaStream_t stream)
{ {
cublasLtMatmulHeuristicResult_t heurResult; cublasLtMatmulHeuristicResult_t heurResult;
    /* Check and time a single algo configuration */ /* Check and time a single algo configuration */
int repeats = kernelRepeats; int repeats = kernelRepeats;
cublasStatus_t algoStatus = cublasStatus_t algoStatus =
cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, &algo, &heurResult); cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, &algo, &heurResult);
if (algoStatus == CUBLAS_STATUS_SUCCESS) { if (algoStatus == CUBLAS_STATUS_SUCCESS) {
if (heurResult.workspaceSize <= workSpaceSizeInBytes) { if (heurResult.workspaceSize <= workSpaceSizeInBytes) {
struct timeval start, end; struct timeval start, end;
cublasStatus_t oneRunStatus; cublasStatus_t oneRunStatus;
cudaDeviceSynchronize(); cudaDeviceSynchronize();
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
for (int loop = 0; loop < repeats; loop++) { for (int loop = 0; loop < repeats; loop++) {
oneRunStatus = cublasLtMatmul(ltHandle, oneRunStatus = cublasLtMatmul(ltHandle,
operationDesc, operationDesc,
alpha, alpha,
A, A,
Adesc, Adesc,
B, B,
Bdesc, Bdesc,
beta, beta,
C, C,
Cdesc, Cdesc,
D, D,
Ddesc, Ddesc,
&algo, &algo,
workSpace, workSpace,
workSpaceSizeInBytes, workSpaceSizeInBytes,
stream); stream);
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
gettimeofday(&end, NULL); gettimeofday(&end, NULL);
if (oneRunStatus != CUBLAS_STATUS_SUCCESS) { if (oneRunStatus != CUBLAS_STATUS_SUCCESS) {
algoStatus = oneRunStatus; algoStatus = oneRunStatus;
} }
float time = diffTime(start, end); float time = diffTime(start, end);
// For the moment only add successful findings // For the moment only add successful findings
if (algoStatus == CUBLAS_STATUS_SUCCESS) { if (algoStatus == CUBLAS_STATUS_SUCCESS) {
perfResults.algo = algo; perfResults.algo = algo;
perfResults.time = time / repeats; perfResults.time = time / repeats;
perfResults.workspaceSize = heurResult.workspaceSize; perfResults.workspaceSize = heurResult.workspaceSize;
perfResults.wavesCount = heurResult.wavesCount; perfResults.wavesCount = heurResult.wavesCount;
} }
} }
else { else {
// printf("not enough workspace! %ld\n", heurResult.workspaceSize); // printf("not enough workspace! %ld\n", heurResult.workspaceSize);
algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace
} }
} }
else { else {
// printf("check fail!\n"); // printf("check fail!\n");
} }
return algoStatus; return algoStatus;
} }
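// Editorial sketch (assumption, not part of the original file): diffTime() used in customMatmulRun()
// above is defined earlier in this source file and is not shown in this hunk. A gettimeofday-based
// helper returning elapsed wall-clock time in milliseconds would look like the function below; the
// name diffTimeMsSketch is illustrative only and nothing in this file calls it. perf.time then holds
// the average per-call latency in whatever unit the real helper returns.
static inline double diffTimeMsSketch(const struct timeval& start, const struct timeval& end)
{
    // seconds difference converted to ms, plus microseconds difference converted to ms
    return (end.tv_sec - start.tv_sec) * 1000.0 + (end.tv_usec - start.tv_usec) * 0.001;
}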
// Sample wrapper running through multiple algo and config attribute combinations for INT8 gemm using the cublasLt low-level // Sample wrapper running through multiple algo and config attribute combinations for INT8 gemm using the cublasLt low-level
// API // API
template<typename T, typename scaleT> template<typename T, typename scaleT>
int LtIgemmCustomFind(cublasLtHandle_t ltHandle, int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
int m, int m,
int n, int n,
int k, int k,
const scaleT* alpha, /* host pointer */ const scaleT* alpha, /* host pointer */
const int8_t* A, const int8_t* A,
const int8_t* B, const int8_t* B,
const scaleT* beta, /* host pointer */ const scaleT* beta, /* host pointer */
T* C, T* C,
void* workSpace, void* workSpace,
size_t workSpaceSize, size_t workSpaceSize,
FILE* fout) FILE* fout)
{ {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS; cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDesc_t operationDesc = NULL; cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaStream_t stream = 0; cudaStream_t stream = 0;
    // SplitK values that we are going to try when SplitK is supported for a given algo // SplitK values that we are going to try when SplitK is supported for a given algo
const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32};
    // Let's try a fixed number of combinations // Let's try a fixed number of combinations
#define ALGO_COMBINATIONS 50000 #define ALGO_COMBINATIONS 50000
int AlgoCombinations = ALGO_COMBINATIONS; int AlgoCombinations = ALGO_COMBINATIONS;
int AlgoCount = 0; int AlgoCount = 0;
    int kernelRepeats = 100;  // number of times the CUDA kernels will be run back to back int kernelRepeats = 100;  // number of times the CUDA kernels will be run back to back
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
int nbAlgoIds = 0; int nbAlgoIds = 0;
#define ALGO_IDS 100 #define ALGO_IDS 100
int algoIdA[ALGO_IDS]; int algoIdA[ALGO_IDS];
cudaDataType_t Atype, Btype, Ctype, scaleType; cudaDataType_t Atype, Btype, Ctype, scaleType;
Atype = CUDA_R_8I; Atype = CUDA_R_8I;
Btype = CUDA_R_8I; Btype = CUDA_R_8I;
if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) { if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) {
Ctype = CUDA_R_32I; Ctype = CUDA_R_32I;
scaleType = CUDA_R_32I; scaleType = CUDA_R_32I;
} }
else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) { else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) {
Ctype = CUDA_R_8I; Ctype = CUDA_R_8I;
scaleType = CUDA_R_32F; scaleType = CUDA_R_32F;
} }
else { else {
printf("[ERROR]<T,scaleT> of igemm is invalid\n"); printf("[ERROR]<T,scaleT> of igemm is invalid\n");
exit(-1); exit(-1);
} }
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else #else
cudaDataType_t computeType = CUDA_R_32I; cudaDataType_t computeType = CUDA_R_32I;
#endif #endif
cublasOperation_t opTranspose = CUBLAS_OP_T; cublasOperation_t opTranspose = CUBLAS_OP_T;
bool use_ORDER_COL32_2R_4R4 = false; bool use_ORDER_COL32_2R_4R4 = false;
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
int device{-1}; int device{-1};
cudaGetDevice(&device); cudaGetDevice(&device);
cudaDeviceProp props; cudaDeviceProp props;
cudaGetDeviceProperties(&props, device); cudaGetDeviceProperties(&props, device);
if (props.major * 10 + props.minor >= 80) { if (props.major * 10 + props.minor >= 80) {
use_ORDER_COL32_2R_4R4 = true; use_ORDER_COL32_2R_4R4 = true;
} }
#endif #endif
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_matrixB; cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
if (use_ORDER_COL32_2R_4R4) { if (use_ORDER_COL32_2R_4R4) {
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
} }
else { else {
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
} }
#else #else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif #endif
int ldaTransform = 32 * m; int ldaTransform = 32 * m;
int ldbTransform; int ldbTransform;
if (use_ORDER_COL32_2R_4R4) { if (use_ORDER_COL32_2R_4R4) {
ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; ldbTransform = 32 * ((n + 32 - 1) / 32) * 32;
} }
else { else {
ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; ldbTransform = 32 * ((n + 8 - 1) / 8) * 8;
} }
int ldcTransform = 32 * m; int ldcTransform = 32 * m;
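    // Worked example (editorial, illustrative values): for m = 256 and n = 100,
    //   ldaTransform = 32 * 256                 = 8192   (COL32 order used for A and C)
    //   ldbTransform = 32 * ceil(100 / 32) * 32 = 4096   (CUBLASLT_ORDER_COL32_2R_4R4, SM >= 80)
    //   ldbTransform = 32 * ceil(100 /  8) *  8 = 3328   (CUBLASLT_ORDER_COL4_4R2_8C otherwise)
    //   ldcTransform = 32 * 256                 = 8192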
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else #else
status = cublasLtMatmulDescCreate(&operationDesc, scaleType); status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
#endif #endif
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
// Create matrix descriptors. // Create matrix descriptors.
status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform); status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform); status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
status = status =
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB)); cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB));
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform); status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
// Request AlgoId available for IGEMM // Request AlgoId available for IGEMM
status = cublasLtMatmulAlgoGetIds( status = cublasLtMatmulAlgoGetIds(
ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds); ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
// Loop over the Algo IDs // Loop over the Algo IDs
for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) { for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) {
cublasLtMatmulAlgo_t algo; cublasLtMatmulAlgo_t algo;
size_t sizeWritten = 0; size_t sizeWritten = 0;
        /* Initialize the algo structure with the given Algo ID */ /* Initialize the algo structure with the given Algo ID */
status = status =
cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo); cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
continue; continue;
} }
        // Query the tile enums supported by that algo // Query the tile enums supported by that algo
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten); cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten);
int nbTiles = int(sizeWritten / sizeof(int)); int nbTiles = int(sizeWritten / sizeof(int));
int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; int* tileA = new int[nbTiles == 0 ? 1 : nbTiles];
if (nbTiles == 0) { if (nbTiles == 0) {
tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED;
nbTiles = 1; nbTiles = 1;
} }
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten); cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten);
int nbStages = int(sizeWritten / sizeof(int)); int nbStages = int(sizeWritten / sizeof(int));
std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages); std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages);
if (nbStages == 0) { if (nbStages == 0) {
stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
nbStages = 1; nbStages = 1;
} }
else { else {
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten); &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten);
} }
#endif #endif
int splitkSupport, redMask, swizzlingMax, customOptionMax; int splitkSupport, redMask, swizzlingMax, customOptionMax;
        // Retrieve algo capability attributes to set up the loop over the different combinations // Retrieve algo capability attributes to set up the loop over the different combinations
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten); &algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten); &algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten); &algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten); &algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten); &algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten);
/* Loop over the different tiles */ /* Loop over the different tiles */
for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) { for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) {
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
/* Loop over different stages count */ /* Loop over different stages count */
for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) { for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) {
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx])); &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]));
#endif #endif
                /* Loop over the different custom options, if any */ /* Loop over the different custom options, if any */
for (int customOption = 0; customOption <= customOptionMax; customOption++) { for (int customOption = 0; customOption <= customOptionMax; customOption++) {
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption)); &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption));
                /* Loop over the supported CTA swizzling values */ /* Loop over the supported CTA swizzling values */
for (int k = 0; k <= swizzlingMax; k++) { for (int k = 0; k <= swizzlingMax; k++) {
int splitK_trial = 0; int splitK_trial = 0;
if (splitkSupport) { if (splitkSupport) {
splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
} }
                    // Loop over the splitK values in the fixed sequence splitKSequenceA, in addition to the case // Loop over the splitK values in the fixed sequence splitKSequenceA, in addition to the case
// where splitK is not enabled // where splitK is not enabled
for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) {
/* Setup attribute of the algo to run */ /* Setup attribute of the algo to run */
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx])); &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]));
int splitK_val = 0; int splitK_val = 0;
int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val)); &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)); &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int));
if (l > 0) { // Split-K case if (l > 0) { // Split-K case
splitK_val = splitKSequenceA[l - 1]; splitK_val = splitKSequenceA[l - 1];
cublasLtMatmulAlgoConfigSetAttribute(&algo, cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM, CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&splitKSequenceA[l - 1], &splitKSequenceA[l - 1],
sizeof(splitKSequenceA[l - 1])); sizeof(splitKSequenceA[l - 1]));
                            /* Go over all the reduction schemes */ /* Go over all the reduction schemes */
for (redScheme = 1; for (redScheme = 1;
redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations); redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations);
redScheme = redScheme << 1) { redScheme = redScheme << 1) {
if (redScheme & redMask) { if (redScheme & redMask) {
cublasLtMatmulAlgoConfigSetAttribute(&algo, cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&redScheme, &redScheme,
sizeof(redScheme)); sizeof(redScheme));
status = customMatmulRun(ltHandle, status = customMatmulRun(ltHandle,
operationDesc, operationDesc,
alpha, /* host or device pointer */ alpha, /* host or device pointer */
A, A,
Adesc, Adesc,
B, B,
Bdesc, Bdesc,
beta, /* host or device pointer */ beta, /* host or device pointer */
C, C,
Cdesc, Cdesc,
C, C,
Cdesc, Cdesc,
algo, algo,
kernelRepeats, kernelRepeats,
workSpace, workSpace,
workSpaceSize, workSpaceSize,
perfResults[AlgoCount], perfResults[AlgoCount],
stream); stream);
perfResults[AlgoCount].status = status; perfResults[AlgoCount].status = status;
if (status == CUBLAS_STATUS_SUCCESS) { if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount++; AlgoCount++;
} }
} // end if } // end if
} // end for } // end for
} }
else { // Non-splitK case else { // Non-splitK case
/* if user preference is ok with workspace */ /* if user preference is ok with workspace */
if (AlgoCount < AlgoCombinations) { if (AlgoCount < AlgoCombinations) {
status = customMatmulRun(ltHandle, status = customMatmulRun(ltHandle,
operationDesc, operationDesc,
alpha, /* host or device pointer */ alpha, /* host or device pointer */
A, A,
Adesc, Adesc,
B, B,
Bdesc, Bdesc,
beta, /* host or device pointer */ beta, /* host or device pointer */
C, C,
Cdesc, Cdesc,
C, C,
Cdesc, Cdesc,
algo, algo,
kernelRepeats, kernelRepeats,
workSpace, workSpace,
workSpaceSize, workSpaceSize,
perfResults[AlgoCount], perfResults[AlgoCount],
stream); stream);
perfResults[AlgoCount].status = status; perfResults[AlgoCount].status = status;
if (status == CUBLAS_STATUS_SUCCESS) { if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount++; AlgoCount++;
} }
} }
} }
} // end l } // end l
} // end k } // end k
} // end customOption } // end customOption
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
} // end stagesIdx } // end stagesIdx
#endif #endif
} // end tileIdx } // end tileIdx
delete[] tileA; delete[] tileA;
} // end idx } // end idx
    // Sort the results by run duration // Sort the results by run duration
std::sort(perfResults, perfResults + AlgoCount, time_compare); std::sort(perfResults, perfResults + AlgoCount, time_compare);
// Print timing and perf details // Print timing and perf details
for (int i = 0, hasPrint = 0; i < AlgoCount; i++) { for (int i = 0, hasPrint = 0; i < AlgoCount; i++) {
printf("result %03d : ", i); printf("result %03d : ", i);
hasPrint = printPerfStructure(m, n, k, perfResults[i], fout, hasPrint); hasPrint = printPerfStructure(m, n, k, perfResults[i], fout, hasPrint);
} }
CLEANUP: CLEANUP:
// Descriptors are no longer needed as all GPU work was already enqueued // Descriptors are no longer needed as all GPU work was already enqueued
if (Cdesc) { if (Cdesc) {
cublasLtMatrixLayoutDestroy(Cdesc); cublasLtMatrixLayoutDestroy(Cdesc);
} }
if (Bdesc) { if (Bdesc) {
cublasLtMatrixLayoutDestroy(Bdesc); cublasLtMatrixLayoutDestroy(Bdesc);
} }
if (Adesc) { if (Adesc) {
cublasLtMatrixLayoutDestroy(Adesc); cublasLtMatrixLayoutDestroy(Adesc);
} }
if (operationDesc) { if (operationDesc) {
cublasLtMatmulDescDestroy(operationDesc); cublasLtMatmulDescDestroy(operationDesc);
} }
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
} }
template int LtIgemmCustomFind(cublasLtHandle_t ltHandle, template int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
int m, int m,
int n, int n,
int k, int k,
const int* alpha, /* host pointer */ const int* alpha, /* host pointer */
const int8_t* A, const int8_t* A,
const int8_t* B, const int8_t* B,
const int* beta, /* host pointer */ const int* beta, /* host pointer */
int32_t* C, int32_t* C,
void* workSpace, void* workSpace,
size_t workSpaceSize, size_t workSpaceSize,
FILE* fout); FILE* fout);
template int LtIgemmCustomFind(cublasLtHandle_t ltHandle, template int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
int m, int m,
int n, int n,
int k, int k,
const float* alpha, /* host pointer */ const float* alpha, /* host pointer */
const int8_t* A, const int8_t* A,
const int8_t* B, const int8_t* B,
const float* beta, /* host pointer */ const float* beta, /* host pointer */
int8_t* C, int8_t* C,
void* workSpace, void* workSpace,
size_t workSpaceSize, size_t workSpaceSize,
FILE* fout); FILE* fout);
template<typename T, typename scaleT> template<typename T, typename scaleT>
int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
int batchCount, int batchCount,
int m, int m,
int n, int n,
int k, int k,
const scaleT* alpha, /* host pointer */ const scaleT* alpha, /* host pointer */
const int8_t* A, const int8_t* A,
const int8_t* B, const int8_t* B,
const scaleT* beta, /* host pointer */ const scaleT* beta, /* host pointer */
T* C, T* C,
void* workSpace, void* workSpace,
size_t workSpaceSize, size_t workSpaceSize,
FILE* fout) FILE* fout)
{ {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS; cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDesc_t operationDesc = NULL; cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaStream_t stream = 0; cudaStream_t stream = 0;
    // SplitK values that we are going to try when SplitK is supported for a given algo // SplitK values that we are going to try when SplitK is supported for a given algo
const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32};
    // Let's try a fixed number of combinations // Let's try a fixed number of combinations
#define ALGO_COMBINATIONS 50000 #define ALGO_COMBINATIONS 50000
int AlgoCombinations = ALGO_COMBINATIONS; int AlgoCombinations = ALGO_COMBINATIONS;
int AlgoCount = 0; int AlgoCount = 0;
    int kernelRepeats = 100;  // number of times the CUDA kernels will be run back to back int kernelRepeats = 100;  // number of times the CUDA kernels will be run back to back
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
int nbAlgoIds = 0; int nbAlgoIds = 0;
#define ALGO_IDS 100 #define ALGO_IDS 100
int algoIdA[ALGO_IDS]; int algoIdA[ALGO_IDS];
cudaDataType_t Atype, Btype, Ctype, scaleType; cudaDataType_t Atype, Btype, Ctype, scaleType;
Atype = CUDA_R_8I; Atype = CUDA_R_8I;
Btype = CUDA_R_8I; Btype = CUDA_R_8I;
if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) { if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) {
Ctype = CUDA_R_32I; Ctype = CUDA_R_32I;
scaleType = CUDA_R_32I; scaleType = CUDA_R_32I;
} }
else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) { else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) {
Ctype = CUDA_R_8I; Ctype = CUDA_R_8I;
scaleType = CUDA_R_32F; scaleType = CUDA_R_32F;
} }
else { else {
printf("[ERROR]<T,scaleT> of igemm is invalid\n"); printf("[ERROR]<T,scaleT> of igemm is invalid\n");
exit(-1); exit(-1);
} }
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else #else
cudaDataType_t computeType = CUDA_R_32I; cudaDataType_t computeType = CUDA_R_32I;
#endif #endif
cublasOperation_t opTranspose = CUBLAS_OP_T; cublasOperation_t opTranspose = CUBLAS_OP_T;
bool use_ORDER_COL32_2R_4R4 = false; bool use_ORDER_COL32_2R_4R4 = false;
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
int device{-1}; int device{-1};
cudaGetDevice(&device); cudaGetDevice(&device);
cudaDeviceProp props; cudaDeviceProp props;
cudaGetDeviceProperties(&props, device); cudaGetDeviceProperties(&props, device);
if (props.major * 10 + props.minor >= 80) { if (props.major * 10 + props.minor >= 80) {
use_ORDER_COL32_2R_4R4 = true; use_ORDER_COL32_2R_4R4 = true;
} }
#endif #endif
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_matrixB; cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
if (use_ORDER_COL32_2R_4R4) { if (use_ORDER_COL32_2R_4R4) {
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
} }
else { else {
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
} }
#else #else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif #endif
int ldaTransform = 32 * m; int ldaTransform = 32 * m;
int ldbTransform; int ldbTransform;
if (use_ORDER_COL32_2R_4R4) { if (use_ORDER_COL32_2R_4R4) {
ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; ldbTransform = 32 * ((n + 32 - 1) / 32) * 32;
} }
else { else {
ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; ldbTransform = 32 * ((n + 8 - 1) / 8) * 8;
} }
int ldcTransform = 32 * m; int ldcTransform = 32 * m;
int64_t stridea, strideb, stridec; int64_t stridea, strideb, stridec;
stridea = m * k; stridea = m * k;
strideb = n * k; strideb = n * k;
stridec = m * n; stridec = m * n;
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else #else
status = cublasLtMatmulDescCreate(&operationDesc, scaleType); status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
#endif #endif
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
// Create matrix descriptors. // Create matrix descriptors.
status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform); status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, sizeof(stridea)); cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, sizeof(stridea));
status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform); status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
status = status =
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB)); cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB));
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, sizeof(strideb)); cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, sizeof(strideb));
status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform); status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec)); cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec));
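    // Worked example (editorial, illustrative values): with batchCount = 3, m = 512, n = 768, k = 768,
    // batch b of A starts at offset b * stridea = b * 512 * 768 = b * 393216 int8 elements into A,
    // batch b of B starts at offset b * strideb = b * 768 * 768 = b * 589824 int8 elements into B, and
    // batch b of C starts at offset b * stridec = b * 512 * 768 = b * 393216 int32 elements into C.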
// Request AlgoId available for IGEMM // Request AlgoId available for IGEMM
status = cublasLtMatmulAlgoGetIds( status = cublasLtMatmulAlgoGetIds(
ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds); ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; goto CLEANUP;
} }
// Loop over the Algo IDs // Loop over the Algo IDs
for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) { for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) {
cublasLtMatmulAlgo_t algo; cublasLtMatmulAlgo_t algo;
size_t sizeWritten = 0; size_t sizeWritten = 0;
        /* Initialize the algo structure with the given Algo ID */ /* Initialize the algo structure with the given Algo ID */
status = status =
cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo); cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
continue; continue;
} }
        // Query the tile enums supported by that algo // Query the tile enums supported by that algo
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten); cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten);
int nbTiles = int(sizeWritten / sizeof(int)); int nbTiles = int(sizeWritten / sizeof(int));
int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; int* tileA = new int[nbTiles == 0 ? 1 : nbTiles];
if (nbTiles == 0) { if (nbTiles == 0) {
tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED;
nbTiles = 1; nbTiles = 1;
} }
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten); cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten);
int nbStages = int(sizeWritten / sizeof(int)); int nbStages = int(sizeWritten / sizeof(int));
std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages); std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages);
if (nbStages == 0) { if (nbStages == 0) {
stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
nbStages = 1; nbStages = 1;
} }
else { else {
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten); &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten);
} }
#endif #endif
int splitkSupport, redMask, swizzlingMax, customOptionMax; int splitkSupport, redMask, swizzlingMax, customOptionMax;
        // Retrieve algo capability attributes to set up the loop over the different combinations // Retrieve algo capability attributes to set up the loop over the different combinations
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten); &algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten); &algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten); &algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten); &algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten); &algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten);
/* Loop over the different tiles */ /* Loop over the different tiles */
for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) { for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) {
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
/* Loop over different stages count */ /* Loop over different stages count */
for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) { for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) {
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx])); &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]));
#endif #endif
                /* Loop over the different custom options, if any */ /* Loop over the different custom options, if any */
for (int customOption = 0; customOption <= customOptionMax; customOption++) { for (int customOption = 0; customOption <= customOptionMax; customOption++) {
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption)); &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption));
                /* Loop over the supported CTA swizzling values */ /* Loop over the supported CTA swizzling values */
for (int k = 0; k <= swizzlingMax; k++) { for (int k = 0; k <= swizzlingMax; k++) {
int splitK_trial = 0; int splitK_trial = 0;
if (splitkSupport) { if (splitkSupport) {
splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
} }
                    // Loop over the splitK values in the fixed sequence splitKSequenceA, in addition to the case // Loop over the splitK values in the fixed sequence splitKSequenceA, in addition to the case
// where splitK is not enabled // where splitK is not enabled
for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) {
/* Setup attribute of the algo to run */ /* Setup attribute of the algo to run */
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx])); &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]));
int splitK_val = 0; int splitK_val = 0;
int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val)); &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)); &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int));
if (l > 0) { // Split-K case if (l > 0) { // Split-K case
splitK_val = splitKSequenceA[l - 1]; splitK_val = splitKSequenceA[l - 1];
cublasLtMatmulAlgoConfigSetAttribute(&algo, cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM, CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&splitKSequenceA[l - 1], &splitKSequenceA[l - 1],
sizeof(splitKSequenceA[l - 1])); sizeof(splitKSequenceA[l - 1]));
                            /* Go over all the reduction schemes */ /* Go over all the reduction schemes */
for (redScheme = 1; for (redScheme = 1;
redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations); redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations);
redScheme = redScheme << 1) { redScheme = redScheme << 1) {
if (redScheme & redMask) { if (redScheme & redMask) {
cublasLtMatmulAlgoConfigSetAttribute(&algo, cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&redScheme, &redScheme,
sizeof(redScheme)); sizeof(redScheme));
status = customMatmulRun(ltHandle, status = customMatmulRun(ltHandle,
operationDesc, operationDesc,
alpha, /* host or device pointer */ alpha, /* host or device pointer */
A, A,
Adesc, Adesc,
B, B,
Bdesc, Bdesc,
beta, /* host or device pointer */ beta, /* host or device pointer */
C, C,
Cdesc, Cdesc,
C, C,
Cdesc, Cdesc,
algo, algo,
kernelRepeats, kernelRepeats,
workSpace, workSpace,
workSpaceSize, workSpaceSize,
perfResults[AlgoCount], perfResults[AlgoCount],
stream); stream);
perfResults[AlgoCount].status = status; perfResults[AlgoCount].status = status;
if (status == CUBLAS_STATUS_SUCCESS) { if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount++; AlgoCount++;
} }
} // end if } // end if
} // end for } // end for
} }
else { // Non-splitK case else { // Non-splitK case
/* if user preference is ok with workspace */ /* if user preference is ok with workspace */
if (AlgoCount < AlgoCombinations) { if (AlgoCount < AlgoCombinations) {
status = customMatmulRun(ltHandle, status = customMatmulRun(ltHandle,
operationDesc, operationDesc,
alpha, /* host or device pointer */ alpha, /* host or device pointer */
A, A,
Adesc, Adesc,
B, B,
Bdesc, Bdesc,
beta, /* host or device pointer */ beta, /* host or device pointer */
C, C,
Cdesc, Cdesc,
C, C,
Cdesc, Cdesc,
algo, algo,
kernelRepeats, kernelRepeats,
workSpace, workSpace,
workSpaceSize, workSpaceSize,
perfResults[AlgoCount], perfResults[AlgoCount],
stream); stream);
perfResults[AlgoCount].status = status; perfResults[AlgoCount].status = status;
if (status == CUBLAS_STATUS_SUCCESS) { if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount++; AlgoCount++;
} }
} }
} }
} // end l } // end l
} // end k } // end k
} // end customOption } // end customOption
#if (CUDART_VERSION >= 11000) #if (CUDART_VERSION >= 11000)
} // end stagesIdx } // end stagesIdx
#endif #endif
} // end tileIdx } // end tileIdx
delete[] tileA; delete[] tileA;
} // end idx } // end idx
    // Sort the results by run duration // Sort the results by run duration
std::sort(perfResults, perfResults + AlgoCount, time_compare); std::sort(perfResults, perfResults + AlgoCount, time_compare);
// Print timing and perf details // Print timing and perf details
for (int i = 0, hasPrint = 0; i < AlgoCount; i++) { for (int i = 0, hasPrint = 0; i < AlgoCount; i++) {
printf("result %03d : ", i); printf("result %03d : ", i);
hasPrint = printBatchPerfStructure(batchCount, m, n, k, perfResults[i], fout, hasPrint); hasPrint = printBatchPerfStructure(batchCount, m, n, k, perfResults[i], fout, hasPrint);
} }
CLEANUP: CLEANUP:
// Descriptors are no longer needed as all GPU work was already enqueued // Descriptors are no longer needed as all GPU work was already enqueued
if (Cdesc) { if (Cdesc) {
cublasLtMatrixLayoutDestroy(Cdesc); cublasLtMatrixLayoutDestroy(Cdesc);
} }
if (Bdesc) { if (Bdesc) {
cublasLtMatrixLayoutDestroy(Bdesc); cublasLtMatrixLayoutDestroy(Bdesc);
} }
if (Adesc) { if (Adesc) {
cublasLtMatrixLayoutDestroy(Adesc); cublasLtMatrixLayoutDestroy(Adesc);
} }
if (operationDesc) { if (operationDesc) {
cublasLtMatmulDescDestroy(operationDesc); cublasLtMatmulDescDestroy(operationDesc);
} }
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
} }
template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
int batchCount, int batchCount,
int m, int m,
int n, int n,
int k, int k,
const int* alpha, /* host pointer */ const int* alpha, /* host pointer */
const int8_t* A, const int8_t* A,
const int8_t* B, const int8_t* B,
const int* beta, /* host pointer */ const int* beta, /* host pointer */
int32_t* C, int32_t* C,
void* workSpace, void* workSpace,
size_t workSpaceSize, size_t workSpaceSize,
FILE* fout); FILE* fout);
template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
int batchCount, int batchCount,
int m, int m,
int n, int n,
int k, int k,
const float* alpha, /* host pointer */ const float* alpha, /* host pointer */
const int8_t* A, const int8_t* A,
const int8_t* B, const int8_t* B,
const float* beta, /* host pointer */ const float* beta, /* host pointer */
int8_t* C, int8_t* C,
void* workSpace, void* workSpace,
size_t workSpaceSize, size_t workSpaceSize,
FILE* fout); FILE* fout);
// initialize matrix in column-major // initialize matrix in column-major
void matInit(int rows, int cols, int8_t* p, int ld) void matInit(int rows, int cols, int8_t* p, int ld)
{ {
srand(time(NULL)); srand(time(NULL));
for (int c = 0; c < cols; c++) { for (int c = 0; c < cols; c++) {
for (int r = 0; r < rows; r++) { for (int r = 0; r < rows; r++) {
int index = r + c * ld; int index = r + c * ld;
p[index] = rand() % 255 - 127; p[index] = rand() % 255 - 127;
} }
} }
} }
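// Usage sketch (editorial, illustrative values): matInit() fills a column-major rows x cols matrix
// with random int8 values in [-127, 127]; for example, a host-side 4 x 8 buffer with leading
// dimension 4:
//     int8_t h_A[4 * 8];
//     matInit(4, 8, h_A, 4);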
int batch_igemm_config(int batchCount, int m, int n, int k, FILE* fout, void* buffer) int batch_igemm_config(int batchCount, int m, int n, int k, FILE* fout, void* buffer)
{ {
printf("batchCount %d m %d n %d k %d\n", batchCount, m, n, k); printf("batchCount %d m %d n %d k %d\n", batchCount, m, n, k);
int alpha = 1; int alpha = 1;
int beta = 0; int beta = 0;
int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major
int8_t* d_B = d_A + batchCount * m * k; // k * n, stored in column-major int8_t* d_B = d_A + batchCount * m * k; // k * n, stored in column-major
int32_t* d_C = (int32_t*)(d_B + batchCount * k * n); // m * n, stored in column-major int32_t* d_C = (int32_t*)(d_B + batchCount * k * n); // m * n, stored in column-major
cublasLtHandle_t ltHandle; cublasLtHandle_t ltHandle;
cublasLtCreate(&ltHandle); cublasLtCreate(&ltHandle);
LtBatchIgemmCustomFind(ltHandle, LtBatchIgemmCustomFind(ltHandle,
batchCount, batchCount,
m, m,
n, n,
k, k,
&alpha, /* host pointer */ &alpha, /* host pointer */
d_A, d_A,
d_B, d_B,
&beta, /* host pointer */ &beta, /* host pointer */
d_C, d_C,
NULL, NULL,
0, 0,
fout); fout);
// free memory // free memory
cublasLtDestroy(ltHandle); cublasLtDestroy(ltHandle);
return 0; return 0;
} }
int igemm_config(int m, int n, int k, FILE* fout, void* buffer) int igemm_config(int m, int n, int k, FILE* fout, void* buffer)
{ {
printf("batchCount %d m %d n %d k %d\n", 1, m, n, k); printf("batchCount %d m %d n %d k %d\n", 1, m, n, k);
int alpha = 1; int alpha = 1;
int beta = 0; int beta = 0;
int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major int8_t* d_A = (int8_t*)buffer; // m * k, stored in column-major
int8_t* d_B = d_A + m * k; // k * n, stored in column-major int8_t* d_B = d_A + m * k; // k * n, stored in column-major
int32_t* d_C = (int32_t*)(d_B + k * n); // m * n, stored in column-major int32_t* d_C = (int32_t*)(d_B + k * n); // m * n, stored in column-major
cublasLtHandle_t ltHandle; cublasLtHandle_t ltHandle;
cublasLtCreate(&ltHandle); cublasLtCreate(&ltHandle);
LtIgemmCustomFind(ltHandle, LtIgemmCustomFind(ltHandle,
m, m,
n, n,
k, k,
&alpha, /* host pointer */ &alpha, /* host pointer */
d_A, d_A,
d_B, d_B,
&beta, /* host pointer */ &beta, /* host pointer */
d_C, d_C,
NULL, NULL,
0, 0,
fout); fout);
cublasLtDestroy(ltHandle); cublasLtDestroy(ltHandle);
return 0; return 0;
} }
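// Editorial sketch (not part of the original file; the name is illustrative): minimum size of the
// device `buffer` that batch_igemm_config() and igemm_config() above assume, given how they carve
// d_A (int8), d_B (int8) and d_C (int32) out of it back to back. Pass batchCount = 1 for the
// non-batched path.
static inline size_t igemm_workspace_bytes_sketch(int batchCount, int m, int n, int k)
{
    const size_t a_bytes = (size_t)batchCount * m * k * sizeof(int8_t);   // d_A
    const size_t b_bytes = (size_t)batchCount * k * n * sizeof(int8_t);   // d_B
    const size_t c_bytes = (size_t)batchCount * m * n * sizeof(int32_t);  // d_C
    return a_bytes + b_bytes + c_bytes;
}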
int generate_encoder_igemm_config( int generate_encoder_igemm_config(
int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend) int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend)
{ {
    // ensure the program is running on SM >= 7.5 // ensure the program is running on SM >= 7.5
struct cudaDeviceProp prop; struct cudaDeviceProp prop;
check_cuda_error(cudaGetDeviceProperties(&prop, 0)); check_cuda_error(cudaGetDeviceProperties(&prop, 0));
if (!(prop.major >= 8 || (prop.major >= 7 && prop.minor >= 5))) { if (!(prop.major >= 8 || (prop.major >= 7 && prop.minor >= 5))) {
printf("[ERROR] INT8 mode > 0 is only supported on device with sm >= 7.5\n "); printf("[ERROR] INT8 mode > 0 is only supported on device with sm >= 7.5\n ");
exit(-1); exit(-1);
} }
printf("Device %s\n", prop.name); printf("Device %s\n", prop.name);
// check config // check config
FILE* fout; FILE* fout;
if (!isAppend) { if (!isAppend) {
fout = fopen(IGEMM_CONFIG, "w+"); fout = fopen(IGEMM_CONFIG, "w+");
fprintf( fprintf(
fout, fout,
"batch_size seq_len head_num size_per_head dataType ### batchCount m n k algoId customOption tile splitK_val swizzle reductionScheme workspaceSize stages exec_time\n"); "batch_size seq_len head_num size_per_head dataType ### batchCount m n k algoId customOption tile splitK_val swizzle reductionScheme workspaceSize stages exec_time\n");
} }
else { else {
fout = fopen(IGEMM_CONFIG, "a+"); fout = fopen(IGEMM_CONFIG, "a+");
std::vector<std::string> config; std::vector<std::string> config;
char line[1024]; char line[1024];
while (fgets(line, 1024, fout) != NULL) { while (fgets(line, 1024, fout) != NULL) {
config.push_back(std::string(line)); config.push_back(std::string(line));
} }
if (config.size() >= MAX_CONFIG_NUM * GEMM_NUM) { if (config.size() >= MAX_CONFIG_NUM * GEMM_NUM) {
int startIdx = config.size() - (MAX_CONFIG_NUM - 1) * GEMM_NUM; int startIdx = config.size() - (MAX_CONFIG_NUM - 1) * GEMM_NUM;
fclose(fout); fclose(fout);
fout = fopen(IGEMM_CONFIG, "w+"); fout = fopen(IGEMM_CONFIG, "w+");
for (int i = startIdx; i < (int)config.size(); i++) { for (int i = startIdx; i < (int)config.size(); i++) {
fprintf(fout, "%s", config[i].c_str()); fprintf(fout, "%s", config[i].c_str());
} }
} }
} }
batch_size_ = batch_size; batch_size_ = batch_size;
seq_len_ = seq_len; seq_len_ = seq_len;
head_num_ = head_num; head_num_ = head_num;
size_per_head_ = size_per_head; size_per_head_ = size_per_head;
int m = batch_size * seq_len; int m = batch_size * seq_len;
int n = head_num * size_per_head; int n = head_num * size_per_head;
int k = n; int k = n;
int batchCount; int batchCount;
printf("***Encoder IGemm Testing Begin***\n"); printf("***Encoder IGemm Testing Begin***\n");
printf("\n-----------------------------\n"); printf("\n-----------------------------\n");
batchCount = 3; batchCount = 3;
m = batch_size * seq_len; m = batch_size * seq_len;
k = head_num * size_per_head; k = head_num * size_per_head;
n = k; n = k;
if (n % 32 != 0 || k % 32 != 0) { if (n % 32 != 0 || k % 32 != 0) {
printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k);
} }
else { else {
batch_igemm_config(batchCount, m, n, k, fout, buffer); batch_igemm_config(batchCount, m, n, k, fout, buffer);
} }
printf("\n-----------------------------\n"); printf("\n-----------------------------\n");
m = seq_len; m = seq_len;
n = seq_len; n = seq_len;
k = size_per_head; k = size_per_head;
batchCount = batch_size * head_num; batchCount = batch_size * head_num;
if (n % 32 != 0 || k % 32 != 0) { if (n % 32 != 0 || k % 32 != 0) {
printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k);
} }
else { else {
batch_igemm_config(batchCount, m, n, k, fout, buffer); batch_igemm_config(batchCount, m, n, k, fout, buffer);
} }
printf("\n-----------------------------\n"); printf("\n-----------------------------\n");
m = seq_len; m = seq_len;
n = size_per_head; n = size_per_head;
k = seq_len; k = seq_len;
batchCount = batch_size * head_num; batchCount = batch_size * head_num;
if (n % 32 != 0 || k % 32 != 0) { if (n % 32 != 0 || k % 32 != 0) {
printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k);
} }
else { else {
batch_igemm_config(batchCount, m, n, k, fout, buffer); batch_igemm_config(batchCount, m, n, k, fout, buffer);
} }
printf("\n-----------------------------\n"); printf("\n-----------------------------\n");
m = batch_size * seq_len; m = batch_size * seq_len;
n = head_num * size_per_head; n = head_num * size_per_head;
k = head_num * size_per_head; k = head_num * size_per_head;
if (n % 32 != 0 || k % 32 != 0) { if (n % 32 != 0 || k % 32 != 0) {
printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k);
} }
else { else {
igemm_config(m, n, k, fout, buffer); igemm_config(m, n, k, fout, buffer);
} }
printf("\n-----------------------------\n"); printf("\n-----------------------------\n");
n = 4 * n; n = 4 * n;
if (n % 32 != 0 || k % 32 != 0) { if (n % 32 != 0 || k % 32 != 0) {
printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k);
} }
else { else {
igemm_config(m, n, k, fout, buffer); igemm_config(m, n, k, fout, buffer);
} }
printf("\n-----------------------------\n"); printf("\n-----------------------------\n");
n = k; n = k;
k = 4 * n; k = 4 * n;
if (n % 32 != 0 || k % 32 != 0) { if (n % 32 != 0 || k % 32 != 0) {
printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k);
} }
else { else {
igemm_config(m, n, k, fout, buffer); igemm_config(m, n, k, fout, buffer);
} }
fclose(fout); fclose(fout);
printf("\n-----------------------------\n"); printf("\n-----------------------------\n");
printf("***Encoder IGemm Testing End***\n"); printf("***Encoder IGemm Testing End***\n");
#ifdef SPARSITY_ENABLED #ifdef SPARSITY_ENABLED
bool do_sparse_test = false; bool do_sparse_test = false;
if (prop.major == 8 && (prop.minor == 0 || prop.minor == 6)) { if (prop.major == 8 && (prop.minor == 0 || prop.minor == 6)) {
do_sparse_test = true; do_sparse_test = true;
} }
if (do_sparse_test) { if (do_sparse_test) {
printf("***cusparseLt Gemm Testing Begin***\n"); printf("***cusparseLt Gemm Testing Begin***\n");
const int spgemm_num = 3; const int spgemm_num = 3;
FILE* fd; FILE* fd;
int line_count = 0; int line_count = 0;
const int ites = 100; const int ites = 100;
struct timeval start, end; struct timeval start, end;
if (!isAppend) { if (!isAppend) {
fd = fopen(SPIGEMM_CONFIG, "w+"); fd = fopen(SPIGEMM_CONFIG, "w+");
} }
else { else {
fd = fopen(SPIGEMM_CONFIG, "a+"); fd = fopen(SPIGEMM_CONFIG, "a+");
std::vector<std::string> config; std::vector<std::string> config;
char line[1024]; char line[1024];
while (fgets(line, 1024, fd) != NULL) { while (fgets(line, 1024, fd) != NULL) {
config.push_back(std::string(line)); config.push_back(std::string(line));
} }
line_count = config.size(); line_count = config.size();
            if (config.size() >= (MAX_CONFIG_NUM * spgemm_num + 1))  // spgemm_num records per run; the +1 is the header row, which is kept if (config.size() >= (MAX_CONFIG_NUM * spgemm_num + 1))  // spgemm_num records per run; the +1 is the header row, which is kept
{ {
int startIdx = config.size() - ((MAX_CONFIG_NUM - 1) * spgemm_num); int startIdx = config.size() - ((MAX_CONFIG_NUM - 1) * spgemm_num);
fclose(fd); fclose(fd);
fd = fopen(SPIGEMM_CONFIG, "w+"); fd = fopen(SPIGEMM_CONFIG, "w+");
fprintf(fd, "%s", config[0].c_str()); fprintf(fd, "%s", config[0].c_str());
for (uint i = startIdx; i < config.size(); i++) { for (uint i = startIdx; i < config.size(); i++) {
fprintf(fd, "%s", config[i].c_str()); fprintf(fd, "%s", config[i].c_str());
} }
line_count = config.size() - (spgemm_num + 3); line_count = config.size() - (spgemm_num + 3);
} }
} }
if (line_count == 0) { if (line_count == 0) {
fprintf( fprintf(
fd, fd,
"batch_size, seq_len, head_num, size_per_head dataType ### batchCount, m, n, k, algoId, exec_time\n"); "batch_size, seq_len, head_num, size_per_head dataType ### batchCount, m, n, k, algoId, exec_time\n");
} }
int M[spgemm_num]; int M[spgemm_num];
int N[spgemm_num]; int N[spgemm_num];
int K[spgemm_num]; int K[spgemm_num];
// gemm1 // gemm1
M[0] = batch_size * seq_len; M[0] = batch_size * seq_len;
K[0] = head_num * size_per_head; K[0] = head_num * size_per_head;
N[0] = K[0]; N[0] = K[0];
// gemm2 // gemm2
M[1] = M[0]; M[1] = M[0];
K[1] = K[0]; K[1] = K[0];
N[1] = 4 * N[0]; N[1] = 4 * N[0];
// gemm3 // gemm3
M[2] = M[0]; M[2] = M[0];
K[2] = 4 * K[0]; K[2] = 4 * K[0];
N[2] = N[0]; N[2] = N[0];
cusparseLtHandle_t handle; cusparseLtHandle_t handle;
CHECK_CUSPARSE(cusparseLtInit(&handle)); CHECK_CUSPARSE(cusparseLtInit(&handle));
cusparseOrder_t col_order = CUSPARSE_ORDER_COL; cusparseOrder_t col_order = CUSPARSE_ORDER_COL;
cusparseOrder_t row_order = CUSPARSE_ORDER_ROW; cusparseOrder_t row_order = CUSPARSE_ORDER_ROW;
cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
cusparseComputeType compute_type = CUSPARSE_COMPUTE_32I; cusparseComputeType compute_type = CUSPARSE_COMPUTE_32I;
unsigned alignment = 16; unsigned alignment = 16;
cudaStream_t stream = 0; cudaStream_t stream = 0;
float alpha2 = 1.0f; float alpha2 = 1.0f;
float beta2 = 0.0f; float beta2 = 0.0f;
for (int i = 0; i < spgemm_num; ++i) { for (int i = 0; i < spgemm_num; ++i) {
// to be compatible with spgemm wrapper, we let A be the weight matrix // to be compatible with spgemm wrapper, we let A be the weight matrix
// so m and n are swapped // so m and n are swapped
// A: mxk B: kxn C:mxn // A: mxk B: kxn C:mxn
int m = N[i], n = M[i], k = K[i]; int m = N[i], n = M[i], k = K[i];
printf("\n-----------------------------\n"); printf("\n-----------------------------\n");
printf("GEMM test %d: [M: %d, K: %d, N: %d]\n", i, m, k, n); printf("GEMM test %d: [M: %d, K: %d, N: %d]\n", i, m, k, n);
int8_t* d_A = (int8_t*)buffer; int8_t* d_A = (int8_t*)buffer;
int8_t* d_B = d_A + m * k; int8_t* d_B = d_A + m * k;
int8_t* d_C = d_B + k * n; int8_t* d_C = d_B + k * n;
int8_t* dA_compressed; int8_t* dA_compressed;
{ {
cusparseLtMatDescriptor_t matA; cusparseLtMatDescriptor_t mat_A;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT)) &handle, &mat_A, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE( CHECK_CUSPARSE(
cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream)) cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
size_t compressed_size; size_t compressed_size;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size)) CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size)); check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream)) CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
cudaError_t result = cudaGetLastError(); cudaError_t result = cudaGetLastError();
if (result) { if (result) {
throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + cudaGetErrorString(result)); throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + cudaGetErrorString(result));
} }
float exec_time = 99999.0f; float exec_time = 99999.0f;
int fast_algo = 0; int fast_algo = 0;
for (int alg = 0; alg < 4; ++alg) { for (int alg = 0; alg < 4; ++alg) {
cudaDeviceSynchronize(); cudaDeviceSynchronize();
cusparseLtMatDescriptor_t matA, matB, matC; cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
void* d_workspace = nullptr; void* d_workspace = nullptr;
int num_streams = 1; int num_streams = 1;
cudaStream_t streams[1] = {stream}; cudaStream_t streams[1] = {stream};
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT)) &handle, &mat_A, m, k, k, alignment, CUDA_R_8I, row_order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_8I, col_order)) CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_8I, col_order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_8I, col_order)) CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_8I, col_order))
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
for (int ite = 0; ite < ites; ++ite) { for (int ite = 0; ite < ites; ++ite) {
// initializing MatDesc takes a lot of time // initializing MatDesc takes a lot of time
// and these descriptors can be cached elsewhere, // and these descriptors can be cached elsewhere,
// whereas caching the MatmulPlan elsewhere causes errors // whereas caching the MatmulPlan elsewhere causes errors
cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulDescriptor_t matmul;
cusparseLtMatmulAlgSelection_t alg_sel; cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan; cusparseLtMatmulPlan_t plan;
CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
&handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type)) &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
CHECK_CUSPARSE( CHECK_CUSPARSE(
cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)) cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute( CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
&handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg))) &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg)))
size_t workspace_size; size_t workspace_size;
CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&handle, &alg_sel, &workspace_size)) CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&handle, &alg_sel, &workspace_size))
CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel, workspace_size)) CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel, workspace_size))
CHECK_CUSPARSE(cusparseLtMatmul(&handle, CHECK_CUSPARSE(cusparseLtMatmul(&handle,
&plan, &plan,
&alpha2, &alpha2,
dA_compressed, dA_compressed,
d_B, d_B,
&beta2, &beta2,
d_C, d_C,
d_C, d_C,
d_workspace, d_workspace,
streams, streams,
num_streams)) num_streams))
CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan)) CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan))
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
gettimeofday(&end, NULL); gettimeofday(&end, NULL);
printf("algo_%d costs %.3fms \n", alg, diffTime(start, end) / ites); printf("algo_%d costs %.3fms \n", alg, diffTime(start, end) / ites);
if (diffTime(start, end) < exec_time) { if (diffTime(start, end) < exec_time) {
exec_time = diffTime(start, end); exec_time = diffTime(start, end);
fast_algo = alg; fast_algo = alg;
} }
} }
exec_time /= ites; exec_time /= ites;
printf("fast_algo %d\n", fast_algo); printf("fast_algo %d\n", fast_algo);
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### 1 %d %d %d %d %f\n", "%d %d %d %d %d ### 1 %d %d %d %d %f\n",
batch_size, batch_size,
seq_len, seq_len,
head_num, head_num,
size_per_head, size_per_head,
HALF_DATATYPE, HALF_DATATYPE,
m, m,
n, n,
k, k,
fast_algo, fast_algo,
exec_time); exec_time);
cudaFree(dA_compressed); cudaFree(dA_compressed);
} }
CHECK_CUSPARSE(cusparseLtDestroy(&handle)) CHECK_CUSPARSE(cusparseLtDestroy(&handle))
fclose(fd); fclose(fd);
printf("***cusparseLt Gemm Testing End***\n"); printf("***cusparseLt Gemm Testing End***\n");
} }
#endif #endif
return 0; return 0;
} }
} // namespace fastertransformer } // namespace fastertransformer
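Note on the timing loop above: the cusparseLt matmul plan is rebuilt on every iteration because, as the inline comment says, only the matrix descriptors can be cached safely. A condensed sketch of that per-iteration sequence, reusing the same cusparseLt calls and the CHECK_CUSPARSE macro from this file (the helper name and its argument bundling are illustrative, not part of the source), is:

// Illustrative helper, not part of the diff: one timing iteration of the
// INT8 cusparseLt SpMM path benchmarked above.  mat_A is the 50%-sparse
// weight descriptor whose data was pruned/compressed beforehand, mat_B/mat_C
// are dense descriptors, and `alg` is the CUSPARSELT_MATMUL_ALG_CONFIG_ID
// candidate being timed.
static void spgemm_once(cusparseLtHandle_t&        handle,
                        cusparseLtMatDescriptor_t& mat_A,
                        cusparseLtMatDescriptor_t& mat_B,
                        cusparseLtMatDescriptor_t& mat_C,
                        cusparseOperation_t        opA,
                        cusparseOperation_t        opB,
                        cusparseComputeType        compute_type,
                        int                        alg,
                        const float*               alpha,
                        const int8_t*              dA_compressed,
                        const int8_t*              d_B,
                        const float*               beta,
                        int8_t*                    d_C,
                        void*                      d_workspace,
                        cudaStream_t*              streams,
                        int                        num_streams)
{
    cusparseLtMatmulDescriptor_t   matmul;
    cusparseLtMatmulAlgSelection_t alg_sel;
    cusparseLtMatmulPlan_t         plan;
    // Describe C = op(A) * op(B); C serves as both the output and the "D" operand.
    CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
        &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
    // Force the specific algorithm id under test instead of the default selection.
    CHECK_CUSPARSE(cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
    CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
        &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg)))
    size_t workspace_size;
    CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&handle, &alg_sel, &workspace_size))
    CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel, workspace_size))
    // The sparse operand must already be compressed (see cusparseLtSpMMACompress2 above).
    CHECK_CUSPARSE(cusparseLtMatmul(
        &handle, &plan, alpha, dA_compressed, d_B, beta, d_C, d_C, d_workspace, streams, num_streams))
    // Plans cannot be cached across calls here, so destroy one per iteration.
    CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan))
}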
/* /*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#pragma once #pragma once
#include "src/fastertransformer/utils/cublasAlgoMap.h" #include "src/fastertransformer/utils/cublasAlgoMap.h"
#include "src/fastertransformer/utils/cuda_utils.h" #include "src/fastertransformer/utils/cuda_utils.h"
#include <algorithm> #include <algorithm>
#include <cublasLt.h> #include <cublasLt.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <map> #include <map>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <sys/time.h> #include <sys/time.h>
#include <time.h> #include <time.h>
#include <unistd.h> #include <unistd.h>
#include <vector> #include <vector>
namespace fastertransformer { namespace fastertransformer {
/* CAUTION : must match cublasLtMatmulTile_t */ /* CAUTION : must match cublasLtMatmulTile_t */
const char* const matmulTileName[] = {"UNDEF", "8x8", "8x16", "16x8", "8x32", "16x16", "32x8", const char* const matmulTileName[] = {"UNDEF", "8x8", "8x16", "16x8", "8x32", "16x16", "32x8",
"8x64", "16x32", "32x16", "64x8", "32x32", "32x64", "64x32", "8x64", "16x32", "32x16", "64x8", "32x32", "32x64", "64x32",
"32x128", "64x64", "128x32", "64x128", "128x64", "64x256", "128x128", "32x128", "64x64", "128x32", "64x128", "128x64", "64x256", "128x128",
"256x64", "64x512", "128x256", "256x128", "512x64", "64x96", "96*64", "256x64", "64x512", "128x256", "256x128", "512x64", "64x96", "96*64",
"96x128", "128x160", "160x128", "192x128", "128x192", "128x96", "END"}; "96x128", "128x160", "160x128", "192x128", "128x192", "128x96", "END"};
int generate_encoder_igemm_config( int generate_encoder_igemm_config(
int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true); int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true);
int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE* fout, int hasPrint); int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE* fout, int hasPrint);
int printBatchPerfStructure( int printBatchPerfStructure(
int batchCount, int m, int n, int k, const customMatmulPerf_t& perf, FILE* fout, int hasPrint); int batchCount, int m, int n, int k, const customMatmulPerf_t& perf, FILE* fout, int hasPrint);
template<typename T, typename scaleT> template<typename T, typename scaleT>
int LtIgemmCustomFind(cublasLtHandle_t ltHandle, int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
int m, int m,
int n, int n,
int k, int k,
const scaleT* alpha, /* host pointer */ const scaleT* alpha, /* host pointer */
const int8_t* A, const int8_t* A,
const int8_t* B, const int8_t* B,
const scaleT* beta, /* host pointer */ const scaleT* beta, /* host pointer */
T* C, T* C,
void* workSpace, void* workSpace,
size_t workSpaceSize, size_t workSpaceSize,
FILE* fout); FILE* fout);
template<typename T, typename scaleT> template<typename T, typename scaleT>
int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
int batchCount, int batchCount,
int m, int m,
int n, int n,
int k, int k,
const scaleT* alpha, /* host pointer */ const scaleT* alpha, /* host pointer */
const int8_t* A, const int8_t* A,
const int8_t* B, const int8_t* B,
const scaleT* beta, /* host pointer */ const scaleT* beta, /* host pointer */
T* C, T* C,
void* workSpace, void* workSpace,
size_t workSpaceSize, size_t workSpaceSize,
FILE* fout); FILE* fout);
void matInit(int rows, int cols, int8_t* p, int ld); void matInit(int rows, int cols, int8_t* p, int ld);
} // namespace fastertransformer } // namespace fastertransformer
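For context, generate_encoder_igemm_config declared above is the entry point exercised by the benchmark code in this commit. A minimal caller sketch under stated assumptions — the standalone main, the chosen shape, and the 1 GiB scratch size are illustrative, and the buffer must be large enough to hold the test A/B/C matrices — could look like the following (assuming this header and cuda_utils.h are included):

// Hypothetical driver, not part of this header: run the INT8 GEMM sweep for one encoder shape.
#include <cuda_runtime.h>

int main()
{
    const int batch_size = 8, seq_len = 128, head_num = 12, size_per_head = 64;

    // Scratch space for the A/B/C test matrices; 1 GiB is an assumed upper bound for this shape.
    void* buffer = nullptr;
    check_cuda_error(cudaMalloc(&buffer, 1ull << 30));

    // isAppend=false starts a fresh config file instead of appending to an existing one.
    fastertransformer::generate_encoder_igemm_config(batch_size, seq_len, head_num, size_per_head, buffer, false);

    check_cuda_error(cudaFree(buffer));
    return 0;
}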
...@@ -617,15 +617,15 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -617,15 +617,15 @@ void generate_gpt_gemm_config(int batch_size,
T* d_C = d_B + k * n * batchCount[i]; T* d_C = d_B + k * n * batchCount[i];
T* dA_compressed; T* dA_compressed;
{ {
cusparseLtMatDescriptor_t matA; cusparseLtMatDescriptor_t mat_A;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT)) &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE( CHECK_CUSPARSE(
cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream)) cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
size_t compressed_size; size_t compressed_size;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size)) CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size)); check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream)) CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
} }
float exec_time = 99999.0f; float exec_time = 99999.0f;
...@@ -633,14 +633,15 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -633,14 +633,15 @@ void generate_gpt_gemm_config(int batch_size,
if (isSparseGemmAvailable(m, n, k)) { if (isSparseGemmAvailable(m, n, k)) {
for (int alg = 0; alg < 4; ++alg) { for (int alg = 0; alg < 4; ++alg) {
cudaDeviceSynchronize(); cudaDeviceSynchronize();
cusparseLtMatDescriptor_t matA, matB, matC; cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
void* d_workspace = nullptr; void* d_workspace = nullptr;
int num_streams = 1; int num_streams = 1;
cudaStream_t streams[1] = {stream}; cudaStream_t streams[1] = {stream};
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT)) &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order)) CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_16F, order)) CHECK_CUSPARSE(
cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
cudaDeviceSynchronize(); cudaDeviceSynchronize();
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
for (int ite = 0; ite < ites; ++ite) { for (int ite = 0; ite < ites; ++ite) {
...@@ -651,7 +652,7 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -651,7 +652,7 @@ void generate_gpt_gemm_config(int batch_size,
cusparseLtMatmulAlgSelection_t alg_sel; cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan; cusparseLtMatmulPlan_t plan;
CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
&handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type)) &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
CHECK_CUSPARSE( CHECK_CUSPARSE(
cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)) cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute( CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
...@@ -616,15 +616,15 @@ void generate_t5_gemm_config(int batch_size, ...@@ -616,15 +616,15 @@ void generate_t5_gemm_config(int batch_size,
T* d_C = d_B + k * n * batchCount[i]; T* d_C = d_B + k * n * batchCount[i];
T* dA_compressed; T* dA_compressed;
{ {
cusparseLtMatDescriptor_t matA; cusparseLtMatDescriptor_t mat_A;
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT)) &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE( CHECK_CUSPARSE(
cusparseLtSpMMAPrune2(&handle, &matA, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream)) cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
size_t compressed_size; size_t compressed_size;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &matA, &compressed_size)) CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size)); check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &matA, true, opA, d_A, dA_compressed, stream)) CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
} }
float exec_time = 99999.0f; float exec_time = 99999.0f;
...@@ -632,14 +632,15 @@ void generate_t5_gemm_config(int batch_size, ...@@ -632,14 +632,15 @@ void generate_t5_gemm_config(int batch_size,
if (isSparseGemmAvailable(m, n, k)) { if (isSparseGemmAvailable(m, n, k)) {
for (int alg = 0; alg < 4; ++alg) { for (int alg = 0; alg < 4; ++alg) {
cudaDeviceSynchronize(); cudaDeviceSynchronize();
cusparseLtMatDescriptor_t matA, matB, matC; cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
void* d_workspace = nullptr; void* d_workspace = nullptr;
int num_streams = 1; int num_streams = 1;
cudaStream_t streams[1] = {stream}; cudaStream_t streams[1] = {stream};
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT)) &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matB, k, n, k, alignment, CUDA_R_16F, order)) CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &matC, m, n, m, alignment, CUDA_R_16F, order)) CHECK_CUSPARSE(
cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
cudaDeviceSynchronize(); cudaDeviceSynchronize();
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
for (int ite = 0; ite < ites; ++ite) { for (int ite = 0; ite < ites; ++ite) {
...@@ -650,7 +651,7 @@ void generate_t5_gemm_config(int batch_size, ...@@ -650,7 +651,7 @@ void generate_t5_gemm_config(int batch_size,
cusparseLtMatmulAlgSelection_t alg_sel; cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan; cusparseLtMatmulPlan_t plan;
CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
&handle, &matmul, opA, opB, &matA, &matB, &matC, &matC, compute_type)) &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
CHECK_CUSPARSE( CHECK_CUSPARSE(
cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)) cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute( CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
......
...@@ -13,4 +13,4 @@ public: ...@@ -13,4 +13,4 @@ public:
virtual void* getSharedObject() = 0; virtual void* getSharedObject() = 0;
}; };
} // namespace fastertransformer } // namespace fastertransformer
\ No newline at end of file
...@@ -27,8 +27,7 @@ namespace fastertransformer { ...@@ -27,8 +27,7 @@ namespace fastertransformer {
class Logger { class Logger {
public: public:
enum Level enum Level {
{
TRACE = 0, TRACE = 0,
DEBUG = 10, DEBUG = 10,
INFO = 20, INFO = 20,
...@@ -41,7 +40,7 @@ public: ...@@ -41,7 +40,7 @@ public:
thread_local Logger instance; thread_local Logger instance;
return instance; return instance;
} }
Logger(Logger const&) = delete; Logger(Logger const&) = delete;
void operator=(Logger const&) = delete; void operator=(Logger const&) = delete;
template<typename... Args> template<typename... Args>
......
...@@ -26,4 +26,4 @@ if (TORCH_VERSION VERSION_GREATER_EQUAL "1.9.0") ...@@ -26,4 +26,4 @@ if (TORCH_VERSION VERSION_GREATER_EQUAL "1.9.0")
target_link_libraries(${LIB_NAME} "${TORCH_LIBRARIES}" fpA_intB_gemm logger) target_link_libraries(${LIB_NAME} "${TORCH_LIBRARIES}" fpA_intB_gemm logger)
else() else()
message("TORCH_VERSION ${TORCH_VERSION} < 1.9.0, skipping compiling th_moe_ops.cc because QUInt4x2 is supported after torch 1.9.0") message("TORCH_VERSION ${TORCH_VERSION} < 1.9.0, skipping compiling th_moe_ops.cc because QUInt4x2 is supported after torch 1.9.0")
endif() endif()
\ No newline at end of file
...@@ -369,4 +369,4 @@ TORCH_LIBRARY(gemm_dq_unit_ops, m) ...@@ -369,4 +369,4 @@ TORCH_LIBRARY(gemm_dq_unit_ops, m)
m.def("benchmark_against_cublas_fp", benchmark_against_cublas_fp); m.def("benchmark_against_cublas_fp", benchmark_against_cublas_fp);
m.def("fused_gemm_dq_bias_act", fused_gemm_dq_bias_act); m.def("fused_gemm_dq_bias_act", fused_gemm_dq_bias_act);
} }
} // namespace torch_ext } // namespace torch_ext
\ No newline at end of file
import torch # flake8: noqa
import unittest import unittest
import torch
def random_tensor(shape, dtype, device, mean=0, std=1): def random_tensor(shape, dtype, device, mean=0, std=1):
return torch.empty(shape, dtype=dtype, device=device).normal_(mean, std) return torch.empty(shape, dtype=dtype, device=device).normal_(mean, std)
class TestGemmDequantize(unittest.TestCase): class TestGemmDequantize(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
torch.classes.load_library("lib/libth_transformer.so") torch.classes.load_library('lib/libth_transformer.so')
torch.classes.load_library("lib/libgemm_dq_unit_ops.so") torch.classes.load_library('lib/libgemm_dq_unit_ops.so')
self.unpack_packed_int4s = torch.ops.fastertransformer.unpack_int4_packed_tensor_to_int8 self.unpack_packed_int4s = torch.ops.fastertransformer.unpack_int4_packed_tensor_to_int8
self.pack_int4s = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4 self.pack_int4s = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
self.fused_gemm_dq = torch.ops.gemm_dq_unit_ops.fused_gemm_dq self.fused_gemm_dq = torch.ops.gemm_dq_unit_ops.fused_gemm_dq
...@@ -20,187 +25,259 @@ class TestGemmDequantize(unittest.TestCase): ...@@ -20,187 +25,259 @@ class TestGemmDequantize(unittest.TestCase):
torch.manual_seed(734876213) torch.manual_seed(734876213)
def dequantize_test_helper(self, weight_type, quant_type): def dequantize_test_helper(self, weight_type, quant_type):
assert quant_type == torch.int8 or quant_type == torch.quint4x2 assert quant_type == torch.int8 or quant_type == torch.quint4x2
lower_bound = -128 if quant_type == torch.int8 else -8 lower_bound = -128 if quant_type == torch.int8 else -8
upper_bound = 127 if quant_type == torch.int8 else 7 upper_bound = 127 if quant_type == torch.int8 else 7
m, n, k = 64, 128, 64 m, n, k = 64, 128, 64
weights = torch.randint(lower_bound, upper_bound, [k, n], dtype=torch.int8, device="cpu") weights = torch.randint(lower_bound,
upper_bound, [k, n],
dtype=torch.int8,
device='cpu')
packed_weight = self.pack_int4s(weights) if quant_type == torch.quint4x2 else weights packed_weight = self.pack_int4s(
cuda_weights = self.preprocess_weights_for_mixed_gemm(packed_weight, quant_type).to("cuda") weights) if quant_type == torch.quint4x2 else weights
weights = weights.to("cuda") cuda_weights = self.preprocess_weights_for_mixed_gemm(
packed_weight, quant_type).to('cuda')
weights = weights.to('cuda')
act = torch.eye(m, dtype=weight_type, device="cuda") act = torch.eye(m, dtype=weight_type, device='cuda')
scales = torch.ones([n], dtype=weight_type, device='cuda') scales = torch.ones([n], dtype=weight_type, device='cuda')
actual = self.fused_gemm_dq(act, cuda_weights, scales) actual = self.fused_gemm_dq(act, cuda_weights, scales)
torch.testing.assert_close(actual, weights, atol=0, rtol=0, check_dtype=False) torch.testing.assert_close(actual,
weights,
atol=0,
rtol=0,
check_dtype=False)
def test_fp16_int8_dequantize(self): def test_fp16_int8_dequantize(self):
self.dequantize_test_helper(torch.float16, torch.int8) self.dequantize_test_helper(torch.float16, torch.int8)
def test_bf16_int8_dequantize(self): def test_bf16_int8_dequantize(self):
self.dequantize_test_helper(torch.bfloat16, torch.int8) self.dequantize_test_helper(torch.bfloat16, torch.int8)
def test_fp16_int4_dequantize(self): def test_fp16_int4_dequantize(self):
self.dequantize_test_helper(torch.float16, torch.quint4x2) self.dequantize_test_helper(torch.float16, torch.quint4x2)
def test_bf16_int4_dequantize(self): def test_bf16_int4_dequantize(self):
self.dequantize_test_helper(torch.bfloat16, torch.quint4x2) self.dequantize_test_helper(torch.bfloat16, torch.quint4x2)
def apply_act(self, inp, act_str): def apply_act(self, inp, act_str):
if act_str == "identity": if act_str == 'identity':
return inp return inp
elif act_str == "silu": elif act_str == 'silu':
return torch.nn.SiLU()(inp) return torch.nn.SiLU()(inp)
elif act_str == "relu": elif act_str == 'relu':
return torch.nn.ReLU()(inp) return torch.nn.ReLU()(inp)
elif act_str == "gelu": elif act_str == 'gelu':
return torch.nn.GELU(approximate="tanh")(inp) return torch.nn.GELU(approximate='tanh')(inp)
else: else:
assert False, "Unsupported activation" assert False, 'Unsupported activation'
def gemm_dequant_test_helper(self, compute_type, weight_dtype, gemm_ms, gemm_ns, gemm_ks, rtol, atol, act_str="only_gemm", benchmark=False): def gemm_dequant_test_helper(self,
assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, "Weight must be quantized" compute_type,
weight_dtype,
gemm_ms,
gemm_ns,
gemm_ks,
rtol,
atol,
act_str='only_gemm',
benchmark=False):
assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, 'Weight must be quantized'
for gemm_k in gemm_ks: for gemm_k in gemm_ks:
for gemm_n in gemm_ns: for gemm_n in gemm_ns:
torch_weights_cpu = random_tensor((gemm_k, gemm_n), dtype=compute_type, device="cpu", mean=0, std=0.002) torch_weights_cpu = random_tensor((gemm_k, gemm_n),
ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(torch_weights_cpu, weight_dtype) dtype=compute_type,
ref_torch_weights = self.unpack_packed_int4s(ref_torch_weights) if weight_dtype == torch.quint4x2 else ref_torch_weights device='cpu',
ref_torch_weights = ref_torch_weights.to("cuda") mean=0,
processed_torch_weights = processed_torch_weights.to("cuda") std=0.002)
torch_weight_scales = torch_weight_scales.to("cuda") ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(
torch_biases = random_tensor((gemm_n), dtype=compute_type, device="cuda", mean=0, std=0.1) torch_weights_cpu, weight_dtype)
ref_torch_weights = self.unpack_packed_int4s(
ref_torch_weights
) if weight_dtype == torch.quint4x2 else ref_torch_weights
ref_torch_weights = ref_torch_weights.to('cuda')
processed_torch_weights = processed_torch_weights.to('cuda')
torch_weight_scales = torch_weight_scales.to('cuda')
torch_biases = random_tensor((gemm_n),
dtype=compute_type,
device='cuda',
mean=0,
std=0.1)
for num_rows in gemm_ms: for num_rows in gemm_ms:
torch_activations = torch.randn(size=(num_rows, gemm_k), dtype=compute_type, device="cuda") torch_activations = torch.randn(size=(num_rows, gemm_k),
dtype=compute_type,
device='cuda')
scales_unsqueezed = torch_weight_scales.unsqueeze(0) scales_unsqueezed = torch_weight_scales.unsqueeze(0)
casted_weights = ref_torch_weights.to(torch_activations.dtype) casted_weights = ref_torch_weights.to(
dequantized_weights = torch.multiply(casted_weights, scales_unsqueezed) torch_activations.dtype)
dequantized_weights = torch.multiply(
casted_weights, scales_unsqueezed)
if benchmark: if benchmark:
assert act_str == "only_gemm", "Benchmarks against cublas must use just GEMM." assert act_str == 'only_gemm', 'Benchmarks against cublas must use just GEMM.'
torch.cuda.profiler.start() torch.cuda.profiler.start()
times, results = self.bench(torch_activations, processed_torch_weights, torch_weight_scales, dequantized_weights, 200) times, results = self.bench(torch_activations,
torch.cuda.profiler.stop() processed_torch_weights,
times = times[0] torch_weight_scales,
cublas_time = times[0].item() dequantized_weights, 200)
ft_time = times[1].item() torch.cuda.profiler.stop()
ft_speedup = cublas_time / ft_time times = times[0]
print("{},{},{},{},{},{}".format(num_rows, gemm_n, gemm_k, cublas_time, ft_time, ft_speedup)) cublas_time = times[0].item()
reference_result = results[0] ft_time = times[1].item()
ft_result = results[1] ft_speedup = cublas_time / ft_time
print('{},{},{},{},{},{}'.format(
num_rows, gemm_n, gemm_k, cublas_time, ft_time,
ft_speedup))
reference_result = results[0]
ft_result = results[1]
else: else:
if act_str == "only_gemm": if act_str == 'only_gemm':
reference_result = torch.matmul(torch_activations, dequantized_weights) reference_result = torch.matmul(
ft_result = self.fused_gemm_dq(torch_activations, processed_torch_weights, torch_weight_scales) torch_activations, dequantized_weights)
else: ft_result = self.fused_gemm_dq(
reference_result = torch.matmul(torch_activations, dequantized_weights) torch_activations, processed_torch_weights,
reference_result += torch_biases.unsqueeze(0) torch_weight_scales)
reference_result = self.apply_act(reference_result, act_str) else:
reference_result = torch.matmul(
ft_result = self.fused_gemm_dq_bias_act(torch_activations, processed_torch_weights, torch_weight_scales, torch_biases, act_str) torch_activations, dequantized_weights)
reference_result += torch_biases.unsqueeze(0)
msg = "FC1 Failed on m={}, n={}, k={}".format(num_rows, gemm_n, gemm_k) reference_result = self.apply_act(
torch.testing.assert_close(ft_result, reference_result, rtol=rtol, atol=atol, msg=msg, check_dtype=False) reference_result, act_str)
ft_result = self.fused_gemm_dq_bias_act(
torch_activations, processed_torch_weights,
torch_weight_scales, torch_biases, act_str)
msg = 'FC1 Failed on m={}, n={}, k={}'.format(
num_rows, gemm_n, gemm_k)
torch.testing.assert_close(ft_result,
reference_result,
rtol=rtol,
atol=atol,
msg=msg,
check_dtype=False)
def test_fp16_int8_gemm(self): def test_fp16_int8_gemm(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8, self.gemm_dequant_test_helper(
gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1], torch.float16,
gemm_ns = [1024, 2048, 4096], torch.int8,
gemm_ks = [4096, 8192, 16384], gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
rtol=0.001, atol=0.002) gemm_ns=[1024, 2048, 4096],
gemm_ks=[4096, 8192, 16384],
rtol=0.001,
atol=0.002)
def test_fp16_int4_gemm(self): def test_fp16_int4_gemm(self):
self.gemm_dequant_test_helper(torch.float16, torch.quint4x2, self.gemm_dequant_test_helper(
gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1], torch.float16,
gemm_ns = [1024, 2048, 4096], torch.quint4x2,
gemm_ks = [4096, 8192, 16384], gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
rtol=0.001, atol=0.002) gemm_ns=[1024, 2048, 4096],
gemm_ks=[4096, 8192, 16384],
rtol=0.001,
atol=0.002)
def test_bf16_int8_gemm(self): def test_bf16_int8_gemm(self):
self.gemm_dequant_test_helper(torch.bfloat16, torch.int8, self.gemm_dequant_test_helper(
gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1], torch.bfloat16,
gemm_ns = [1024, 2048, 4096], torch.int8,
gemm_ks = [4096, 8192, 16384], gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
rtol=0.01, atol=0.01) gemm_ns=[1024, 2048, 4096],
gemm_ks=[4096, 8192, 16384],
rtol=0.01,
atol=0.01)
def test_bf16_int4_gemm(self): def test_bf16_int4_gemm(self):
self.gemm_dequant_test_helper(torch.bfloat16, torch.quint4x2, self.gemm_dequant_test_helper(
gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1], torch.bfloat16,
gemm_ns = [1024, 2048, 4096], torch.quint4x2,
gemm_ks = [4096, 8192, 16384], gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
rtol=0.01, atol=0.01) gemm_ns=[1024, 2048, 4096],
gemm_ks=[4096, 8192, 16384],
rtol=0.01,
atol=0.01)
def test_fp16_int8_gemm_bias(self): def test_fp16_int8_gemm_bias(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8, self.gemm_dequant_test_helper(torch.float16,
gemm_ms = [256], torch.int8,
gemm_ns = [1024], gemm_ms=[256],
gemm_ks = [8192], gemm_ns=[1024],
rtol=0.001, atol=0.002, gemm_ks=[8192],
act_str="identity") rtol=0.001,
atol=0.002,
act_str='identity')
def test_fp16_int8_gemm_bias_relu(self): def test_fp16_int8_gemm_bias_relu(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8, self.gemm_dequant_test_helper(torch.float16,
gemm_ms = [256], torch.int8,
gemm_ns = [1024], gemm_ms=[256],
gemm_ks = [8192], gemm_ns=[1024],
rtol=0.001, atol=0.002, gemm_ks=[8192],
act_str="relu") rtol=0.001,
atol=0.002,
act_str='relu')
def test_fp16_int8_gemm_bias_gelu(self): def test_fp16_int8_gemm_bias_gelu(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8, self.gemm_dequant_test_helper(torch.float16,
gemm_ms = [256], torch.int8,
gemm_ns = [1024], gemm_ms=[256],
gemm_ks = [8192], gemm_ns=[1024],
rtol=0.001, atol=0.002, gemm_ks=[8192],
act_str="gelu") rtol=0.001,
atol=0.002,
act_str='gelu')
def test_fp16_int8_gemm_bias_silu(self): def test_fp16_int8_gemm_bias_silu(self):
self.gemm_dequant_test_helper(torch.float16, torch.int8, self.gemm_dequant_test_helper(torch.float16,
gemm_ms = [256], torch.int8,
gemm_ns = [1024], gemm_ms=[256],
gemm_ks = [8192], gemm_ns=[1024],
rtol=0.001, atol=0.002, gemm_ks=[8192],
act_str="silu") rtol=0.001,
atol=0.002,
act_str='silu')
def bench_helper(self, act_type, quant_type, rtol, atol): def bench_helper(self, act_type, quant_type, rtol, atol):
# Warm up; using bfloat16 here since it seems to reliably use cuBLAS. # Warm up; using bfloat16 here since it seems to reliably use cuBLAS.
x = random_tensor([20480, 20480], torch.bfloat16, device="cuda") x = random_tensor([20480, 20480], torch.bfloat16, device='cuda')
warm_iters = 30 warm_iters = 30
for iter in range(warm_iters): for iter in range(warm_iters):
res = x @ x res = x @ x
m_shapes = torch.arange(0, 12) m_shapes = torch.arange(0, 12)
m_shapes = 2 ** m_shapes m_shapes = 2**m_shapes
self.gemm_dequant_test_helper(act_type, quant_type, self.gemm_dequant_test_helper(act_type,
gemm_ms = [128], quant_type,
gemm_ns = [1536], gemm_ms=[128],
gemm_ks = [12288], gemm_ns=[1536],
rtol=rtol, atol=atol, benchmark=True) gemm_ks=[12288],
rtol=rtol,
atol=atol,
benchmark=True)
@unittest.skip("This is a benchmark so don't run by default") @unittest.skip("This is a benchmark so don't run by default")
def test_fp16_int8_cublas(self): def test_fp16_int8_cublas(self):
self.bench_helper(torch.float16, torch.int8, 1e-3, 0.002) self.bench_helper(torch.float16, torch.int8, 1e-3, 0.002)
@unittest.skip("This is a benchmark so don't run by default") @unittest.skip("This is a benchmark so don't run by default")
def test_bf16_int8_cublas(self): def test_bf16_int8_cublas(self):
self.bench_helper(torch.bfloat16, torch.int8, 1e-2, 1e-2) self.bench_helper(torch.bfloat16, torch.int8, 1e-2, 1e-2)
@unittest.skip("This is a benchmark so don't run by default") @unittest.skip("This is a benchmark so don't run by default")
def test_fp16_int4_cublas(self): def test_fp16_int4_cublas(self):
self.bench_helper(torch.float16, torch.quint4x2, 1e-3, 0.002) self.bench_helper(torch.float16, torch.quint4x2, 1e-3, 0.002)
@unittest.skip("This is a benchmark so don't run by default") @unittest.skip("This is a benchmark so don't run by default")
def test_bf16_int4_cublas(self): def test_bf16_int4_cublas(self):
self.bench_helper(torch.bfloat16, torch.quint4x2, 1e-2, 1e-2) self.bench_helper(torch.bfloat16, torch.quint4x2, 1e-2, 1e-2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
\ No newline at end of file
...@@ -21,4 +21,4 @@ add_definitions(-DTORCH_CUDA=1) ...@@ -21,4 +21,4 @@ add_definitions(-DTORCH_CUDA=1)
set(EXE_NAME "int8_gemm_test") set(EXE_NAME "int8_gemm_test")
add_executable(${EXE_NAME} ${int8_test_files}) add_executable(${EXE_NAME} ${int8_test_files})
set_target_properties(${EXE_NAME} PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) set_target_properties(${EXE_NAME} PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(${EXE_NAME} PUBLIC "${TORCH_LIBRARIES}" int8_gemm tensor logger) target_link_libraries(${EXE_NAME} PUBLIC "${TORCH_LIBRARIES}" int8_gemm tensor logger)
\ No newline at end of file
...@@ -38,9 +38,9 @@ namespace ft = fastertransformer; ...@@ -38,9 +38,9 @@ namespace ft = fastertransformer;
template<typename T> template<typename T>
void int8_gemm_test( void int8_gemm_test(
const int m, const int m,
const int n, const int n,
const int k, const int k,
const at::ScalarType output_data_type, const at::ScalarType output_data_type,
const QuantMode quant_mode, const QuantMode quant_mode,
const int iters) const int iters)
......