Commit 9484fd1c authored by xiabo's avatar xiabo
Browse files

Adapt to 0.1.0

parent 477f2db8
...@@ -185,124 +185,126 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -185,124 +185,126 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_)); cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
if (findAlgo) { if (findAlgo) {
if (info.stages != -1) { if (info.stages != -1) {
using_cublasLt = true; using_cublasLt = false;
} }
else { else {
using_cublasLt = false; using_cublasLt = false;
} }
} }
if (using_cublasLt) {
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaDataType_t scaleType;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType;
#else
cudaDataType_t computeType;
#endif
if (is_fp16_computeType) {
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_16F;
#else
computeType = CUDA_R_16F;
#endif
scaleType = CUDA_R_16F;
}
else {
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_32F;
#else
computeType = CUDA_R_32F;
#endif
scaleType = CUDA_R_32F;
}
// --------------------------------------
// Create descriptors for the original matrices
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
#if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else
cublasLtMatmulDescCreate(&operationDesc, computeType);
#endif
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)); // if (using_cublasLt) {
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)); // if (0) {
// cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatmulAlgo_t algo; // cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
void* workSpace = cublas_workspace_; // cudaDataType_t scaleType;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; // #if (CUDART_VERSION >= 11000)
if (findAlgo) { // cublasComputeType_t computeType;
if (info.workspaceSize > workspaceSize) { // #else
findAlgo = 0; // cudaDataType_t computeType;
} // #endif
else {
cublasLtMatmulAlgoInit( // if (is_fp16_computeType) {
cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo); // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // computeType = CUBLAS_COMPUTE_16F;
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption)); // #else
cublasLtMatmulAlgoConfigSetAttribute( // computeType = CUDA_R_16F;
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile)); // #endif
cublasLtMatmulAlgoConfigSetAttribute( // scaleType = CUDA_R_16F;
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val)); // }
cublasLtMatmulAlgoConfigSetAttribute( // else {
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle)); // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute(&algo, // computeType = CUBLAS_COMPUTE_32F;
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, // #else
&(info.reductionScheme), // computeType = CUDA_R_32F;
sizeof(info.reductionScheme)); // #endif
// scaleType = CUDA_R_32F;
#if (CUDART_VERSION >= 11000) // }
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)); // // --------------------------------------
#endif // // Create descriptors for the original matrices
// cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) // cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
&algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId)); // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute(&algo, // cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, // #else
&(info.cluster_shapeId), // cublasLtMatmulDescCreate(&operationDesc, computeType);
sizeof(info.cluster_shapeId)); // #endif
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
&algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId)); // cublasLtMatmulAlgo_t algo;
cublasLtMatmulAlgoConfigSetAttribute( // void* workSpace = cublas_workspace_;
&algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode)); // int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
#endif // if (findAlgo) {
} // if (info.workspaceSize > workspaceSize) {
} // findAlgo = 0;
// }
cublasLtMatmul(cublaslt_handle_, // else {
operationDesc, // cublasLtMatmulAlgoInit(
alpha, // cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
A, // cublasLtMatmulAlgoConfigSetAttribute(
Adesc, // &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
B, // cublasLtMatmulAlgoConfigSetAttribute(
Bdesc, // &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
beta, // cublasLtMatmulAlgoConfigSetAttribute(
C, // &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
Cdesc, // cublasLtMatmulAlgoConfigSetAttribute(
C, // &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
Cdesc, // cublasLtMatmulAlgoConfigSetAttribute(&algo,
(findAlgo == 1 ? (&algo) : NULL), // CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
workSpace, // &(info.reductionScheme),
workspaceSize, // sizeof(info.reductionScheme));
stream_);
// // #if (CUDART_VERSION >= 11000)
cublasLtMatmulDescDestroy(operationDesc); // // cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatrixLayoutDestroy(Adesc); // // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
cublasLtMatrixLayoutDestroy(Bdesc); // // #endif
cublasLtMatrixLayoutDestroy(Cdesc);
sync_check_cuda_error(); // #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
} // cublasLtMatmulAlgoConfigSetAttribute(
else { // &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(&algo,
// CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
// &(info.cluster_shapeId),
// sizeof(info.cluster_shapeId));
// #elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
// #endif
// }
// }
// // cublasLtMatmul(cublaslt_handle_,
// // operationDesc,
// // alpha,
// // A,
// // Adesc,
// // B,
// // Bdesc,
// // beta,
// // C,
// // Cdesc,
// // C,
// // Cdesc,
// // (findAlgo == 1 ? (&algo) : NULL),
// // workSpace,
// // workspaceSize,
// // stream_);
// cublasLtMatmulDescDestroy(operationDesc);
// cublasLtMatrixLayoutDestroy(Adesc);
// cublasLtMatrixLayoutDestroy(Bdesc);
// cublasLtMatrixLayoutDestroy(Cdesc);
// sync_check_cuda_error();
// }
// else {
int cublasAlgo = info.algoId; int cublasAlgo = info.algoId;
check_cuda_error(cublasGemmEx(cublas_handle_, check_cuda_error(cublasGemmEx(cublas_handle_,
transa, transa,
...@@ -324,7 +326,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -324,7 +326,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
computeType_, computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo))); static_cast<cublasGemmAlgo_t>(cublasAlgo)));
sync_check_cuda_error(); sync_check_cuda_error();
} // }
mu_->unlock(); mu_->unlock();
} }
...@@ -341,7 +343,7 @@ void cublasMMWrapper::setFP16GemmConfig() ...@@ -341,7 +343,7 @@ void cublasMMWrapper::setFP16GemmConfig()
Atype_ = CUDA_R_16F; Atype_ = CUDA_R_16F;
Btype_ = CUDA_R_16F; Btype_ = CUDA_R_16F;
Ctype_ = CUDA_R_16F; Ctype_ = CUDA_R_16F;
computeType_ = CUDA_R_32F; computeType_ = CUDA_R_16F;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
...@@ -381,81 +383,81 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type) ...@@ -381,81 +383,81 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
return FLOAT_DATATYPE; return FLOAT_DATATYPE;
} }
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
// input, weight, output are row-major // // input, weight, output are row-major
// only works for cublas 11.x // // only works for cublas 11.x
void cublasMMWrapper::Gemm(cublasOperation_t transa, // void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasOperation_t transb, // cublasOperation_t transb,
const int m, // const int m,
const int n, // const int n,
const int k, // const int k,
const void* A, // const void* A,
const int lda, // const int lda,
const void* B, // const void* B,
const int ldb, // const int ldb,
const void* bias, // const void* bias,
void* C, // void* C,
const int ldc) // const int ldc)
{ // {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); // TM_LOG_DEBUG(__PRETTY_FUNCTION__);
cudaDataType_t Atype, Btype, Ctype; // cudaDataType_t Atype, Btype, Ctype;
cublasComputeType_t computeType; // cublasComputeType_t computeType;
cudaDataType_t scaleType; // cudaDataType_t scaleType;
float alpha_float = 1.0f; // float alpha_float = 1.0f;
float beta_float = 0.0f; // float beta_float = 0.0f;
half alpha_half = half(1.0f); // half alpha_half = half(1.0f);
half beta_half = half(0.0f); // half beta_half = half(0.0f);
void * alpha, *beta; // void * alpha, *beta;
// int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; // // int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
if (Atype_ == CUDA_R_32F) { // if (Atype_ == CUDA_R_32F) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32; // computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype = CUDA_R_32F; // Atype = CUDA_R_32F;
Btype = CUDA_R_32F; // Btype = CUDA_R_32F;
Ctype = CUDA_R_32F; // Ctype = CUDA_R_32F;
scaleType = CUDA_R_32F; // scaleType = CUDA_R_32F;
alpha = &alpha_float; // alpha = &alpha_float;
beta = &beta_float; // beta = &beta_float;
} // }
else if (Atype_ == CUDA_R_16BF) { // else if (Atype_ == CUDA_R_16BF) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32; // computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype = CUDA_R_16BF; // Atype = CUDA_R_16BF;
Btype = CUDA_R_16BF; // Btype = CUDA_R_16BF;
Ctype = CUDA_R_16BF; // Ctype = CUDA_R_16BF;
scaleType = CUDA_R_32F; // scaleType = CUDA_R_32F;
alpha = &alpha_float; // alpha = &alpha_float;
beta = &beta_float; // beta = &beta_float;
} // }
else { // else {
computeType = CUBLAS_COMPUTE_16F; // computeType = CUBLAS_COMPUTE_16F;
Atype = CUDA_R_16F; // Atype = CUDA_R_16F;
Btype = CUDA_R_16F; // Btype = CUDA_R_16F;
Ctype = CUDA_R_16F; // Ctype = CUDA_R_16F;
scaleType = CUDA_R_16F; // scaleType = CUDA_R_16F;
alpha = &alpha_half; // alpha = &alpha_half;
beta = &beta_half; // beta = &beta_half;
} // }
cublasLtMatmulDesc_t operationDesc = NULL; // cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; // cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS; // cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda); // cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb); // cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb);
cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc); // cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); // cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
check_cuda_error(cublasLtMatmul( // // check_cuda_error(cublasLtMatmul(
cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_)); // // cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
cublasLtMatrixLayoutDestroy(Adesc); // cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc); // cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc); // cublasLtMatrixLayoutDestroy(Cdesc);
cublasLtMatmulDescDestroy(operationDesc); // cublasLtMatmulDescDestroy(operationDesc);
} // }
#endif // #endif
void cublasMMWrapper::setStream(cudaStream_t stream) void cublasMMWrapper::setStream(cudaStream_t stream)
{ {
stream_ = stream; stream_ = stream;
...@@ -985,7 +987,8 @@ void cublasMMWrapper::_Int8Gemm(const int m, ...@@ -985,7 +987,8 @@ void cublasMMWrapper::_Int8Gemm(const int m,
* - 0: int8 * int8 -> int32 -> int8 * - 0: int8 * int8 -> int32 -> int8
* - 1: int8 * int8 -> int32 -> int32 * - 1: int8 * int8 -> int32 -> int32
*/ */
#if (CUBLAS_VERSION) <= 11601 // #if (CUBLAS_VERSION) <= 11601
#if 1
FT_CHECK_WITH_INFO(false, "CUBLAS version too low."); FT_CHECK_WITH_INFO(false, "CUBLAS version too low.");
#else #else
......
...@@ -207,20 +207,20 @@ public: ...@@ -207,20 +207,20 @@ public:
CublasDataType getCublasDataType(cudaDataType_t data_type); CublasDataType getCublasDataType(cudaDataType_t data_type);
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
void Gemm(cublasOperation_t transa, // void Gemm(cublasOperation_t transa,
cublasOperation_t transb, // cublasOperation_t transb,
const int m, // const int m,
const int n, // const int n,
const int k, // const int k,
const void* A, // const void* A,
const int lda, // const int lda,
const void* B, // const void* B,
const int ldb, // const int ldb,
const void* bias, // const void* bias,
void* C, // void* C,
const int ldc); // const int ldc);
#endif // #endif
void stridedBatchedGemm(cublasOperation_t transa, void stridedBatchedGemm(cublasOperation_t transa,
cublasOperation_t transb, cublasOperation_t transb,
......
...@@ -322,7 +322,7 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val) ...@@ -322,7 +322,7 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val)
int16_t int16_in; int16_t int16_in;
}; };
fp16 = val; fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); // asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0]; return int8[0];
} }
...@@ -333,20 +333,31 @@ __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) ...@@ -333,20 +333,31 @@ __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
int8[0] = cuda_cast<int8_t>(val.x); // int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y); // int8[1] = cuda_cast<int8_t>(val.y);
int8[0] = cuda_cast<int8_t>((val.data[0]));
int8[1] = cuda_cast<int8_t>((val.data[1]));
return int16; return int16;
} }
template<>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
    // Convert a float to int8 with round-to-nearest and saturation.
    //
    // The previous implementation of this cast was the single PTX instruction
    // `cvt.rni.sat.s8.f32` (round to nearest integer, saturate to the s8 range).
    // A portable replacement must keep the rounding step: a bare
    // static_cast<int8_t>(val) truncates toward zero (2.7f -> 2, -2.7f -> -2),
    // which silently biases quantized values toward zero.
    //
    // __float2int_rn rounds to the nearest integer (ties to even). Clamping is
    // done on the resulting int, so values just above 127 or just below -128
    // saturate after rounding instead of being cast out of range.
    // NOTE(review): NaN handling relies on __float2int_rn; the CUDA math API
    // documents it as returning 0 for NaN, matching the PTX `.sat` behaviour.
    // Confirm the same holds on the target toolchain.
    const int rounded = __float2int_rn(val);
    if (rounded > 127) {
        return static_cast<int8_t>(127);
    }
    if (rounded < -128) {
        return static_cast<int8_t>(-128);
    }
    return static_cast<int8_t>(rounded);
}
template<> template<>
...@@ -528,7 +539,8 @@ __device__ inline To cuda_max(Ti val) ...@@ -528,7 +539,8 @@ __device__ inline To cuda_max(Ti val)
template<>
__device__ inline half cuda_max(half2 val)
{
    // Return the larger of the two values packed in `val`.
    // The lanes are read through `val.data[...]` rather than `val.x` / `val.y`.
    // NOTE(review): assumes this platform's half2 exposes its lanes via `data[]` - confirm.
    const auto lane0 = val.data[0];
    const auto lane1 = val.data[1];
    if (lane0 > lane1) {
        return lane0;
    }
    return lane1;
}
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template<> template<>
......
...@@ -152,17 +152,17 @@ void initCustomAllReduceComm(std::vector<std::shared_ptr<AbstractCustomComm>>* c ...@@ -152,17 +152,17 @@ void initCustomAllReduceComm(std::vector<std::shared_ptr<AbstractCustomComm>>* c
return; return;
} }
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020 // #if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
for (size_t i = 0; i < rank_size; i++) { // for (size_t i = 0; i < rank_size; i++) {
custom_all_reduce_comms->push_back(std::make_shared<CustomAllReduceComm<T>>(rank_size, i)); // custom_all_reduce_comms->push_back(std::make_shared<CustomAllReduceComm<T>>(rank_size, i));
} // }
custom_all_reduce_comms->at(0)->allocateAndExchangePeerAccessPointer(custom_all_reduce_comms); // custom_all_reduce_comms->at(0)->allocateAndExchangePeerAccessPointer(custom_all_reduce_comms);
#else // #else
TM_LOG_WARNING("Custom All Reduce is not supported before CUDA 11.2. Using NCCL as Comm."); TM_LOG_WARNING("Custom All Reduce is not supported before CUDA 11.2. Using NCCL as Comm.");
for (size_t i = 0; i < rank_size; i++) { for (size_t i = 0; i < rank_size; i++) {
custom_all_reduce_comms->push_back(nullptr); custom_all_reduce_comms->push_back(nullptr);
} }
#endif // #endif
} }
// Template instantiation // Template instantiation
......
...@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file) ...@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file)
stream_ = stream; stream_ = stream;
mutex_ = new std::mutex(); // mutex per process mutex_ = new std::mutex(); // mutex per process
check_cuda_error(cublasCreate(&cublas_handle_)); check_cuda_error(cublasCreate(&cublas_handle_));
check_cuda_error(cublasLtCreate(&cublaslt_handle_)); // check_cuda_error(cublasLtCreate(&cublaslt_handle_));
check_cuda_error(cublasSetStream(cublas_handle_, stream)); check_cuda_error(cublasSetStream(cublas_handle_, stream));
if (allocator_ != nullptr) { if (allocator_ != nullptr) {
...@@ -41,7 +41,7 @@ Gemm::~Gemm() ...@@ -41,7 +41,7 @@ Gemm::~Gemm()
allocator_->free((void**)(&workspace_)); allocator_->free((void**)(&workspace_));
allocator_ = nullptr; allocator_ = nullptr;
} }
cublasLtDestroy(cublaslt_handle_); // cublasLtDestroy(cublaslt_handle_);
cublasDestroy(cublas_handle_); cublasDestroy(cublas_handle_);
delete cublas_algo_map_; delete cublas_algo_map_;
delete mutex_; delete mutex_;
...@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa, ...@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa,
mutex_->lock(); mutex_->lock();
// Use cublas as default in FP32 and cublasLt as default in FP16 // Use cublas as default in FP32 and cublasLt as default in FP16
bool is_fp16_compute_type = compute_type_ == TYPE_FP16; bool is_fp16_compute_type = compute_type_ == TYPE_FP16;
bool using_cublasLt = Atype == TYPE_FP16; // bool using_cublasLt = Atype == TYPE_FP16;
bool using_cublasLt = (Atype == TYPE_FP16) ? false : false;
int batch_count = 1; int batch_count = 1;
half h_alpha = (half)alpha; half h_alpha = (half)alpha;
...@@ -267,82 +268,83 @@ void Gemm::gemm(const GemmOp transa, ...@@ -267,82 +268,83 @@ void Gemm::gemm(const GemmOp transa,
using_cublasLt = (info.stages != -1); using_cublasLt = (info.stages != -1);
} }
if (using_cublasLt) { // if (using_cublasLt) {
const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k; // if(0) {
const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m; // const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k;
const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n; // const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m;
const size_t b_cols = (b_op == getCublasOperation(GEMM_OP_N)) ? _n : k; // const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n;
// const size_t b_cols = (b_op == getCublasOperation(GEMM_OP_N)) ? _n : k;
cublasLtMatmulDesc_t matmul_desc = NULL;
cublasLtMatrixLayout_t a_desc = NULL, b_desc = NULL, c_desc = NULL; // cublasLtMatmulDesc_t matmul_desc = NULL;
cudaDataType_t scale_type = getCublasDataType(compute_type_); // cublasLtMatrixLayout_t a_desc = NULL, b_desc = NULL, c_desc = NULL;
auto compute_type = getCublasComputeType(compute_type_); // cudaDataType_t scale_type = getCublasDataType(compute_type_);
// auto compute_type = getCublasComputeType(compute_type_);
// --------------------------------------
// Create descriptors for the original matrices // // --------------------------------------
cublasLtMatrixLayoutCreate(&a_desc, a_type, a_rows, a_cols, _lda); // // Create descriptors for the original matrices
cublasLtMatrixLayoutCreate(&b_desc, b_type, b_rows, b_cols, _ldb); // cublasLtMatrixLayoutCreate(&a_desc, a_type, a_rows, a_cols, _lda);
cublasLtMatrixLayoutCreate(&c_desc, c_type, _m, _n, ldc); // cublasLtMatrixLayoutCreate(&b_desc, b_type, b_rows, b_cols, _ldb);
#if (CUDART_VERSION >= 11000) // cublasLtMatrixLayoutCreate(&c_desc, c_type, _m, _n, ldc);
cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_type); // #if (CUDART_VERSION >= 11000)
#else // cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_type);
cublasLtMatmulDescCreate(&matmul_desc, compute_type); // #else
#endif // cublasLtMatmulDescCreate(&matmul_desc, compute_type);
// #endif
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &a_op, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &b_op, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &a_op, sizeof(cublasOperation_t));
// cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &b_op, sizeof(cublasOperation_t));
cublasLtMatmulAlgo_t algo;
void* workspace = workspace_; // cublasLtMatmulAlgo_t algo;
int workspace_size = workspace_ == nullptr ? 0 : CUBLAS_WORKSPACE_SIZE; // void* workspace = workspace_;
if (findAlgo) { // int workspace_size = workspace_ == nullptr ? 0 : CUBLAS_WORKSPACE_SIZE;
if (info.workspaceSize > workspace_size) { // if (findAlgo) {
findAlgo = 0; // if (info.workspaceSize > workspace_size) {
} // findAlgo = 0;
else { // }
cublasLtMatmulAlgoInit( // else {
cublaslt_handle_, compute_type, scale_type, a_type, b_type, c_type, c_type, info.algoId, &algo); // cublasLtMatmulAlgoInit(
cublasLtMatmulAlgoConfigSetAttribute( // cublaslt_handle_, compute_type, scale_type, a_type, b_type, c_type, c_type, info.algoId, &algo);
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption)); // cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgoConfigSetAttribute( // &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile)); // cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgoConfigSetAttribute( // &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val)); // cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgoConfigSetAttribute( // &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle)); // cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgoConfigSetAttribute( // &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int)); // cublasLtMatmulAlgoConfigSetAttribute(
#if (CUDART_VERSION >= 11000) // &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int));
cublasLtMatmulAlgoConfigSetAttribute( // #if (CUDART_VERSION >= 11000)
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)); // cublasLtMatmulAlgoConfigSetAttribute(
#endif // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
} // #endif
} // }
// }
cublasLtMatmul(cublaslt_handle_,
matmul_desc, // cublasLtMatmul(cublaslt_handle_,
alpha_ptr, // matmul_desc,
a_data_ptr, // alpha_ptr,
a_desc, // a_data_ptr,
b_data_ptr, // a_desc,
b_desc, // b_data_ptr,
beta_ptr, // b_desc,
C, // beta_ptr,
c_desc, // C,
C, // c_desc,
c_desc, // C,
(findAlgo == 1 ? (&algo) : NULL), // c_desc,
workspace, // (findAlgo == 1 ? (&algo) : NULL),
workspace_size, // workspace,
stream_); // workspace_size,
// stream_);
cublasLtMatmulDescDestroy(matmul_desc);
cublasLtMatrixLayoutDestroy(a_desc); // cublasLtMatmulDescDestroy(matmul_desc);
cublasLtMatrixLayoutDestroy(b_desc); // cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatrixLayoutDestroy(c_desc); // cublasLtMatrixLayoutDestroy(b_desc);
sync_check_cuda_error(); // cublasLtMatrixLayoutDestroy(c_desc);
} // sync_check_cuda_error();
else { // }
// else {
cudaDataType_t compute_type = getCublasDataType(compute_type_); cudaDataType_t compute_type = getCublasDataType(compute_type_);
int cublas_algo = info.algoId; int cublas_algo = info.algoId;
check_cuda_error(cublasGemmEx(cublas_handle_, check_cuda_error(cublasGemmEx(cublas_handle_,
...@@ -365,7 +367,7 @@ void Gemm::gemm(const GemmOp transa, ...@@ -365,7 +367,7 @@ void Gemm::gemm(const GemmOp transa,
compute_type, compute_type,
static_cast<cublasGemmAlgo_t>(cublas_algo))); static_cast<cublasGemmAlgo_t>(cublas_algo)));
sync_check_cuda_error(); sync_check_cuda_error();
} // }
mutex_->unlock(); mutex_->unlock();
} }
...@@ -1033,19 +1035,19 @@ cudaDataType_t getCublasDataType(DataType dtype) ...@@ -1033,19 +1035,19 @@ cudaDataType_t getCublasDataType(DataType dtype)
} }
} }
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasComputeType_t getCublasComputeType(DataType ctype) // cublasComputeType_t getCublasComputeType(DataType ctype)
{ // {
switch (ctype) { // switch (ctype) {
case TYPE_FP16: // case TYPE_FP16:
return CUBLAS_COMPUTE_16F; // return CUBLAS_COMPUTE_16F;
case TYPE_FP32: // case TYPE_FP32:
return CUBLAS_COMPUTE_32F; // return CUBLAS_COMPUTE_32F;
default: // default:
throw GemmNotSupportedException("Not supported cublas compute type."); // throw GemmNotSupportedException("Not supported cublas compute type.");
} // }
} // }
#else // #else
cudaDataType_t getCublasComputeType(DataType ctype) cudaDataType_t getCublasComputeType(DataType ctype)
{ {
switch (ctype) { switch (ctype) {
...@@ -1057,7 +1059,7 @@ cudaDataType_t getCublasComputeType(DataType ctype) ...@@ -1057,7 +1059,7 @@ cudaDataType_t getCublasComputeType(DataType ctype)
throw GemmNotSupportedException("Not supported cublas compute type."); throw GemmNotSupportedException("Not supported cublas compute type.");
} }
} }
#endif // #endif
cublasOperation_t getCublasOperation(GemmOp op) cublasOperation_t getCublasOperation(GemmOp op)
{ {
......
...@@ -622,11 +622,11 @@ std::shared_ptr<Gemm> ...@@ -622,11 +622,11 @@ std::shared_ptr<Gemm>
createGemm(IAllocator* allocator, cudaStream_t stream, bool sparse = false, bool quantized = false); createGemm(IAllocator* allocator, cudaStream_t stream, bool sparse = false, bool quantized = false);
cudaDataType_t getCublasDataType(DataType dtype); cudaDataType_t getCublasDataType(DataType dtype);
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasComputeType_t getCublasComputeType(DataType dtype); // cublasComputeType_t getCublasComputeType(DataType dtype);
#else // #else
cudaDataType_t getCublasComputeType(DataType dtype); cudaDataType_t getCublasComputeType(DataType dtype);
#endif // #endif
cublasOperation_t getCublasOperation(GemmOp op); cublasOperation_t getCublasOperation(GemmOp op);
std::string getGemmOpString(const GemmOp& op); std::string getGemmOpString(const GemmOp& op);
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
set(gemm_func_files set(gemm_func_files
gemm_func.cc gemm_func.cc
...@@ -51,59 +52,71 @@ set(swin_gemm_func_files ...@@ -51,59 +52,71 @@ set(swin_gemm_func_files
swin_gemm_func.cc swin_gemm_func.cc
) )
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(gemm_func STATIC ${gemm_func_files}) add_library(gemm_func STATIC ${gemm_func_files})
target_link_libraries(gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart cuda_utils logger) #target_link_libraries(gemm_func PUBLIC cublas cublasLt cudart cuda_utils logger)
set_property(TARGET gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(gemm_func PUBLIC cublas cudart cuda_utils logger)
set_property(TARGET gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(encoder_gemm_func STATIC ${encoder_gemm_func_files}) add_library(encoder_gemm_func STATIC ${encoder_gemm_func_files})
target_link_libraries(encoder_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(encoder_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries(encoder_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(encoder_gemm_func PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(encoder_gemm_func PUBLIC cusparse -lcusparseLt)
endif() endif()
set_property(TARGET encoder_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET encoder_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET encoder_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET encoder_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(encoder_igemm_func STATIC ${encoder_igemm_func_files}) add_library(encoder_igemm_func STATIC ${encoder_igemm_func_files})
target_link_libraries(encoder_igemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart cuda_utils logger) #target_link_libraries(encoder_igemm_func PUBLIC cublas cublasLt cudart cuda_utils logger)
target_link_libraries(encoder_igemm_func PUBLIC cublas cudart cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(encoder_igemm_func PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(encoder_igemm_func PUBLIC cusparse -lcusparseLt)
endif() endif()
set_property(TARGET encoder_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET encoder_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET encoder_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET encoder_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(decoding_gemm_func STATIC ${decoding_gemm_func_files}) add_library(decoding_gemm_func STATIC ${decoding_gemm_func_files})
target_link_libraries(decoding_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(decoding_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property(TARGET decoding_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(decoding_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
set_property(TARGET decoding_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET decoding_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET decoding_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(gpt_gemm_func STATIC ${gpt_gemm_func_files}) add_library(gpt_gemm_func STATIC ${gpt_gemm_func_files})
target_link_libraries(gpt_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(gpt_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries(gpt_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(gpt_gemm_func PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(gpt_gemm_func PUBLIC cusparse -lcusparseLt)
endif() endif()
set_property(TARGET gpt_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET gpt_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET gpt_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET gpt_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(xlnet_gemm_func STATIC ${xlnet_gemm_func_files}) add_library(xlnet_gemm_func STATIC ${xlnet_gemm_func_files})
target_link_libraries(xlnet_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(xlnet_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property(TARGET xlnet_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(xlnet_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
set_property(TARGET xlnet_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET xlnet_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET xlnet_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(t5_gemm_func STATIC ${t5_gemm_func_files}) add_library(t5_gemm_func STATIC ${t5_gemm_func_files})
target_link_libraries(t5_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(t5_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries(t5_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(t5_gemm_func PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(t5_gemm_func PUBLIC cusparse -lcusparseLt)
endif() endif()
set_property(TARGET t5_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET t5_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET t5_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET t5_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(swin_igemm_func STATIC ${swin_igemm_func_files}) add_library(swin_igemm_func STATIC ${swin_igemm_func_files})
target_link_libraries(swin_igemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func encoder_igemm_func cuda_utils logger) #target_link_libraries(swin_igemm_func PUBLIC cublas cublasLt cudart gemm_func encoder_igemm_func cuda_utils logger)
set_property(TARGET swin_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(swin_igemm_func PUBLIC cublas cudart gemm_func encoder_igemm_func cuda_utils logger)
set_property(TARGET swin_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET swin_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET swin_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(swin_gemm_func STATIC ${swin_gemm_func_files}) add_library(swin_gemm_func STATIC ${swin_gemm_func_files})
target_link_libraries(swin_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(swin_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property(TARGET swin_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(swin_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
set_property(TARGET swin_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET swin_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET swin_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
...@@ -130,8 +130,8 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -130,8 +130,8 @@ void generate_decoding_gemm_config(int batch_size,
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -148,16 +148,19 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -148,16 +148,19 @@ void generate_decoding_gemm_config(int batch_size,
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -166,11 +169,14 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -166,11 +169,14 @@ void generate_decoding_gemm_config(int batch_size,
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T>::Type; // using scaleT = typename ScaleTypeConverter<T>::Type;
using scaleT = typename ScaleTypeConverter<T, true>::Type;
scaleT alpha = (scaleT)1.0f; scaleT alpha = (scaleT)1.0f;
scaleT beta = (scaleT)0.0f; scaleT beta = (scaleT)0.0f;
...@@ -241,38 +247,39 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -241,38 +247,39 @@ void generate_decoding_gemm_config(int batch_size,
const int ALGO_COMBINATIONS = 5000; const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size * beam_width, // batch_size * beam_width,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure(batch_size * beam_width, // printPerfStructure(batch_size * beam_width,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
perfResults[0], // perfResults[0],
fd, // fd,
data_type, // data_type,
0); // 0);
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
...@@ -127,8 +127,8 @@ void generate_encoder_gemm_config( ...@@ -127,8 +127,8 @@ void generate_encoder_gemm_config(
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -145,16 +145,19 @@ void generate_encoder_gemm_config( ...@@ -145,16 +145,19 @@ void generate_encoder_gemm_config(
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -163,11 +166,14 @@ void generate_encoder_gemm_config( ...@@ -163,11 +166,14 @@ void generate_encoder_gemm_config(
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T, false>::Type; // using scaleT = typename ScaleTypeConverter<T, false>::Type;
using scaleT = typename ScaleTypeConverter<T, true>::Type;
scaleT alpha = (scaleT)1.0f; scaleT alpha = (scaleT)1.0f;
scaleT beta = (scaleT)0.0f; scaleT beta = (scaleT)0.0f;
...@@ -331,30 +337,31 @@ void generate_encoder_gemm_config( ...@@ -331,30 +337,31 @@ void generate_encoder_gemm_config(
// Let try a fixed number of combinations // Let try a fixed number of combinations
const int ALGO_COMBINATIONS = 5000; const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size, // batch_size,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure( // printPerfStructure(
batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); // batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time = perfResults[0].time; // exec_time = perfResults[0].time;
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
This diff is collapsed.
...@@ -223,8 +223,8 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -223,8 +223,8 @@ void generate_gpt_gemm_config(int batch_size,
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -244,7 +244,8 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -244,7 +244,8 @@ void generate_gpt_gemm_config(int batch_size,
DType = CUDA_R_32F; DType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
...@@ -252,9 +253,11 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -252,9 +253,11 @@ void generate_gpt_gemm_config(int batch_size,
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
DType = CUDA_R_16F; DType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -264,8 +267,10 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -264,8 +267,10 @@ void generate_gpt_gemm_config(int batch_size,
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
DType = CUDA_R_16BF; DType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
...@@ -293,12 +298,24 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -293,12 +298,24 @@ void generate_gpt_gemm_config(int batch_size,
DType_FP8[9] = CUDA_R_16BF; DType_FP8[9] = CUDA_R_16BF;
#endif #endif
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
float alpha = (float)1.0f; // float alpha = (float)1.0f;
float beta = (float)0.0f; // float beta = (float)0.0f;
float f_alpha = (float)1.0f;
float f_beta = (float)0.0f;
half h_alpha = (half)(f_alpha);
half h_beta = (half)(f_beta);
int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;
const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
printf("***Encoder Gemm Testing Begin***\n"); printf("***Encoder Gemm Testing Begin***\n");
printf("***Cublas Gemm Testing Begin***\n"); printf("***Cublas Gemm Testing Begin***\n");
...@@ -342,7 +359,7 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -342,7 +359,7 @@ void generate_gpt_gemm_config(int batch_size,
max_input_len, max_input_len,
max_input_len, max_input_len,
size_per_head, size_per_head,
&alpha, &f_alpha,
d_B, d_B,
BType, BType,
size_per_head, size_per_head,
...@@ -351,13 +368,13 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -351,13 +368,13 @@ void generate_gpt_gemm_config(int batch_size,
AType, AType,
size_per_head, size_per_head,
max_input_len * size_per_head, max_input_len * size_per_head,
&beta, &f_beta,
d_C, d_C,
CUDA_R_32F, // CType, CUDA_R_32F, // CType,
max_input_len, max_input_len,
max_input_len * max_input_len, max_input_len * max_input_len,
batchCount[i], batchCount[i],
computeType, CUDA_R_32F,
static_cast<cublasGemmAlgo_t>(algo)); static_cast<cublasGemmAlgo_t>(algo));
} }
else if (i == 2) { else if (i == 2) {
...@@ -456,44 +473,45 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -456,44 +473,45 @@ void generate_gpt_gemm_config(int batch_size,
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
// for gpt, computeType & scaleType should be FP32 // for gpt, computeType & scaleType should be FP32
LtHgemmCustomFind<T, float>(ltHandle, // LtHgemmCustomFind<T, float>(ltHandle,
batch_size * beam_width, // batch_size * beam_width,
i == 1 || i == 2 ? max_input_len : 1, // i == 1 || i == 2 ? max_input_len : 1,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS, // ALGO_COMBINATIONS,
DType_FP8[i], // DType_FP8[i],
batchCount[i], // batchCount[i],
strideA[i], // strideA[i],
strideB[i], // strideB[i],
strideD[i]); // strideD[i]);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure(batch_size * beam_width, // printPerfStructure(batch_size * beam_width,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
perfResults[0], // perfResults[0],
fd, // fd,
data_type, // data_type,
0, // 0,
batchCount[i]); // batchCount[i]);
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
...@@ -133,8 +133,8 @@ void generate_swin_gemm_config( ...@@ -133,8 +133,8 @@ void generate_swin_gemm_config(
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -151,16 +151,19 @@ void generate_swin_gemm_config( ...@@ -151,16 +151,19 @@ void generate_swin_gemm_config(
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -169,11 +172,14 @@ void generate_swin_gemm_config( ...@@ -169,11 +172,14 @@ void generate_swin_gemm_config(
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T, false>::Type; // using scaleT = typename ScaleTypeConverter<T, false>::Type;
using scaleT = typename ScaleTypeConverter<T, true>::Type;
scaleT alpha = (scaleT)1.0f; scaleT alpha = (scaleT)1.0f;
scaleT beta = (scaleT)0.0f; scaleT beta = (scaleT)0.0f;
...@@ -309,30 +315,31 @@ void generate_swin_gemm_config( ...@@ -309,30 +315,31 @@ void generate_swin_gemm_config(
const int ALGO_COMBINATIONS = 5000; const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size, // batch_size,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure( // printPerfStructure(
batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); // batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time = perfResults[0].time; // exec_time = perfResults[0].time;
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
...@@ -144,23 +144,23 @@ int igemm_config_INT8IO(int m, int n, int k, FILE* fout, void* buffer) ...@@ -144,23 +144,23 @@ int igemm_config_INT8IO(int m, int n, int k, FILE* fout, void* buffer)
int8_t* d_B = d_A + m * k; // k * n, stored in column-major int8_t* d_B = d_A + m * k; // k * n, stored in column-major
int8_t* d_C = (int8_t*)(d_B + k * n); // m * n, stored in column-major int8_t* d_C = (int8_t*)(d_B + k * n); // m * n, stored in column-major
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
cublasLtCreate(&ltHandle); // cublasLtCreate(&ltHandle);
LtIgemmCustomFind(ltHandle, // LtIgemmCustomFind(ltHandle,
m, // m,
n, // n,
k, // k,
&alpha, /* host pointer */ // &alpha, /* host pointer */
d_A, // d_A,
d_B, // d_B,
&beta, /* host pointer */ // &beta, /* host pointer */
d_C, // d_C,
NULL, // NULL,
0, // 0,
fout); // fout);
cublasLtDestroy(ltHandle); // cublasLtDestroy(ltHandle);
return 0; return 0;
} }
......
...@@ -195,8 +195,8 @@ void generate_t5_gemm_config(int batch_size, ...@@ -195,8 +195,8 @@ void generate_t5_gemm_config(int batch_size,
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -213,16 +213,19 @@ void generate_t5_gemm_config(int batch_size, ...@@ -213,16 +213,19 @@ void generate_t5_gemm_config(int batch_size,
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -231,8 +234,10 @@ void generate_t5_gemm_config(int batch_size, ...@@ -231,8 +234,10 @@ void generate_t5_gemm_config(int batch_size,
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
float f_alpha = (float)1.0f; float f_alpha = (float)1.0f;
...@@ -442,60 +447,61 @@ void generate_t5_gemm_config(int batch_size, ...@@ -442,60 +447,61 @@ void generate_t5_gemm_config(int batch_size,
scaleT alpha_scale = (scaleT)1.0f; scaleT alpha_scale = (scaleT)1.0f;
scaleT beta_scale = (scaleT)0.0f; scaleT beta_scale = (scaleT)0.0f;
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
m, // m,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&(alpha_scale), // &(alpha_scale),
d_B, // d_B,
d_A, // d_A,
&(beta_scale), // &(beta_scale),
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
} }
else { else {
LtHgemmCustomFind<T, float>(ltHandle, // LtHgemmCustomFind<T, float>(ltHandle,
m, // m,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&(f_alpha), // &(f_alpha),
d_B, // d_B,
d_A, // d_A,
&(f_beta), // &(f_beta),
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
} }
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure(batch_size * (i <= 5 || i == 1 ? 1 : beam_width), // printPerfStructure(batch_size * (i <= 5 || i == 1 ? 1 : beam_width),
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
perfResults[0], // perfResults[0],
fd, // fd,
data_type, // data_type,
0); // 0);
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
...@@ -218,8 +218,8 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -218,8 +218,8 @@ void generate_xlnet_gemm_config(int batch_size,
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -236,16 +236,19 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -236,16 +236,19 @@ void generate_xlnet_gemm_config(int batch_size,
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -254,12 +257,15 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -254,12 +257,15 @@ void generate_xlnet_gemm_config(int batch_size,
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T, false>::Type; // using scaleT = typename ScaleTypeConverter<T, false>::Type;
using scaleT = typename ScaleTypeConverter<T, true>::Type;
scaleT alpha = (scaleT)1.0f; scaleT alpha = (scaleT)1.0f;
scaleT beta = (scaleT)0.0f; scaleT beta = (scaleT)0.0f;
...@@ -358,30 +364,31 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -358,30 +364,31 @@ void generate_xlnet_gemm_config(int batch_size,
const int ALGO_COMBINATIONS = 5000; const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size, // batch_size,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure( // printPerfStructure(
batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); // batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time = perfResults[0].time; // exec_time = perfResults[0].time;
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment