Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Lmdeploy
Commits
9484fd1c
Commit
9484fd1c
authored
Dec 20, 2023
by
xiabo
Browse files
Adapt to 0.1.0
parent
477f2db8
Changes
56
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
2119 additions
and
2037 deletions
+2119
-2037
src/turbomind/utils/cublasMMWrapper.cc
src/turbomind/utils/cublasMMWrapper.cc
+192
-189
src/turbomind/utils/cublasMMWrapper.h
src/turbomind/utils/cublasMMWrapper.h
+14
-14
src/turbomind/utils/cuda_type_utils.cuh
src/turbomind/utils/cuda_type_utils.cuh
+22
-10
src/turbomind/utils/custom_ar_comm.cc
src/turbomind/utils/custom_ar_comm.cc
+7
-7
src/turbomind/utils/gemm.cc
src/turbomind/utils/gemm.cc
+96
-94
src/turbomind/utils/gemm.h
src/turbomind/utils/gemm.h
+4
-4
src/turbomind/utils/gemm_test/CMakeLists.txt
src/turbomind/utils/gemm_test/CMakeLists.txt
+45
-32
src/turbomind/utils/gemm_test/decoding_gemm_func.cc
src/turbomind/utils/gemm_test/decoding_gemm_func.cc
+48
-41
src/turbomind/utils/gemm_test/encoder_gemm_func.cc
src/turbomind/utils/gemm_test/encoder_gemm_func.cc
+40
-33
src/turbomind/utils/gemm_test/encoder_igemm_func.cc
src/turbomind/utils/gemm_test/encoder_igemm_func.cc
+711
-711
src/turbomind/utils/gemm_test/gemm_func.cc
src/turbomind/utils/gemm_test/gemm_func.cc
+708
-708
src/turbomind/utils/gemm_test/gpt_gemm_func.cc
src/turbomind/utils/gemm_test/gpt_gemm_func.cc
+71
-53
src/turbomind/utils/gemm_test/swin_gemm_func.cc
src/turbomind/utils/gemm_test/swin_gemm_func.cc
+40
-33
src/turbomind/utils/gemm_test/swin_igemm_func.cc
src/turbomind/utils/gemm_test/swin_igemm_func.cc
+17
-17
src/turbomind/utils/gemm_test/t5_gemm_func.cc
src/turbomind/utils/gemm_test/t5_gemm_func.cc
+64
-58
src/turbomind/utils/gemm_test/xlnet_gemm_func.cc
src/turbomind/utils/gemm_test/xlnet_gemm_func.cc
+40
-33
No files found.
src/turbomind/utils/cublasMMWrapper.cc
View file @
9484fd1c
...
@@ -185,124 +185,126 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
...
@@ -185,124 +185,126 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasLtMatmulAlgo_info
info
=
cublas_algo_map_
->
getAlgo
(
batch_count
,
m
,
n
,
k
,
getCublasDataType
(
Atype_
));
cublasLtMatmulAlgo_info
info
=
cublas_algo_map_
->
getAlgo
(
batch_count
,
m
,
n
,
k
,
getCublasDataType
(
Atype_
));
if
(
findAlgo
)
{
if
(
findAlgo
)
{
if
(
info
.
stages
!=
-
1
)
{
if
(
info
.
stages
!=
-
1
)
{
using_cublasLt
=
tru
e
;
using_cublasLt
=
fals
e
;
}
}
else
{
else
{
using_cublasLt
=
false
;
using_cublasLt
=
false
;
}
}
}
}
if
(
using_cublasLt
)
{
cublasLtMatmulDesc_t
operationDesc
=
NULL
;
cublasLtMatrixLayout_t
Adesc
=
NULL
,
Bdesc
=
NULL
,
Cdesc
=
NULL
;
cudaDataType_t
scaleType
;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t
computeType
;
#else
cudaDataType_t
computeType
;
#endif
if
(
is_fp16_computeType
)
{
#if (CUDART_VERSION >= 11000)
computeType
=
CUBLAS_COMPUTE_16F
;
#else
computeType
=
CUDA_R_16F
;
#endif
scaleType
=
CUDA_R_16F
;
}
else
{
#if (CUDART_VERSION >= 11000)
computeType
=
CUBLAS_COMPUTE_32F
;
#else
computeType
=
CUDA_R_32F
;
#endif
scaleType
=
CUDA_R_32F
;
}
// --------------------------------------
// Create descriptors for the original matrices
cublasLtMatrixLayoutCreate
(
&
Adesc
,
Atype_
,
transa
==
CUBLAS_OP_N
?
m
:
k
,
transa
==
CUBLAS_OP_N
?
k
:
m
,
lda
);
cublasLtMatrixLayoutCreate
(
&
Bdesc
,
Btype_
,
transb
==
CUBLAS_OP_N
?
k
:
n
,
transb
==
CUBLAS_OP_N
?
n
:
k
,
ldb
);
cublasLtMatrixLayoutCreate
(
&
Cdesc
,
Ctype_
,
m
,
n
,
ldc
);
#if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate
(
&
operationDesc
,
computeType
,
scaleType
);
#else
cublasLtMatmulDescCreate
(
&
operationDesc
,
computeType
);
#endif
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
transa
,
sizeof
(
cublasOperation_t
));
// if (using_cublasLt) {
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
transb
,
sizeof
(
cublasOperation_t
));
// if (0) {
// cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatmulAlgo_t
algo
;
// cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
void
*
workSpace
=
cublas_workspace_
;
// cudaDataType_t scaleType;
int
workspaceSize
=
cublas_workspace_
==
NULL
?
0
:
CUBLAS_WORKSPACE_SIZE
;
// #if (CUDART_VERSION >= 11000)
if
(
findAlgo
)
{
// cublasComputeType_t computeType;
if
(
info
.
workspaceSize
>
workspaceSize
)
{
// #else
findAlgo
=
0
;
// cudaDataType_t computeType;
}
// #endif
else
{
cublasLtMatmulAlgoInit
(
// if (is_fp16_computeType) {
cublaslt_handle_
,
computeType
,
scaleType
,
Atype_
,
Btype_
,
Ctype_
,
Ctype_
,
info
.
algoId
,
&
algo
);
// #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute
(
// computeType = CUBLAS_COMPUTE_16F;
&
algo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
(
info
.
customOption
),
sizeof
(
info
.
customOption
));
// #else
cublasLtMatmulAlgoConfigSetAttribute
(
// computeType = CUDA_R_16F;
&
algo
,
CUBLASLT_ALGO_CONFIG_TILE_ID
,
&
(
info
.
tile
),
sizeof
(
info
.
tile
));
// #endif
cublasLtMatmulAlgoConfigSetAttribute
(
// scaleType = CUDA_R_16F;
&
algo
,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
&
(
info
.
splitK_val
),
sizeof
(
info
.
splitK_val
));
// }
cublasLtMatmulAlgoConfigSetAttribute
(
// else {
&
algo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
(
info
.
swizzle
),
sizeof
(
info
.
swizzle
));
// #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
// computeType = CUBLAS_COMPUTE_32F;
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
// #else
&
(
info
.
reductionScheme
),
// computeType = CUDA_R_32F;
sizeof
(
info
.
reductionScheme
));
// #endif
// scaleType = CUDA_R_32F;
#if (CUDART_VERSION >= 11000)
// }
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
(
info
.
stages
),
sizeof
(
info
.
stages
));
// // --------------------------------------
#endif
// // Create descriptors for the original matrices
// cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
// cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
cublasLtMatmulAlgoConfigSetAttribute
(
// cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
&
algo
,
CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID
,
&
(
info
.
inner_shapeId
),
sizeof
(
info
.
inner_shapeId
));
// #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
// cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID
,
// #else
&
(
info
.
cluster_shapeId
),
// cublasLtMatmulDescCreate(&operationDesc, computeType);
sizeof
(
info
.
cluster_shapeId
));
// #endif
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
cublasLtMatmulAlgoConfigSetAttribute
(
// cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
&
algo
,
CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID
,
&
(
info
.
mma_shapeId
),
sizeof
(
info
.
mma_shapeId
));
// cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID
,
&
(
info
.
cga_shapeId
),
sizeof
(
info
.
cga_shapeId
));
// cublasLtMatmulAlgo_t algo;
cublasLtMatmulAlgoConfigSetAttribute
(
// void* workSpace = cublas_workspace_;
&
algo
,
CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE
,
&
(
info
.
sche_mode
),
sizeof
(
info
.
sche_mode
));
// int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
#endif
// if (findAlgo) {
}
// if (info.workspaceSize > workspaceSize) {
}
// findAlgo = 0;
// }
cublasLtMatmul
(
cublaslt_handle_
,
// else {
operationDesc
,
// cublasLtMatmulAlgoInit(
alpha
,
// cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
A
,
// cublasLtMatmulAlgoConfigSetAttribute(
Adesc
,
// &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
B
,
// cublasLtMatmulAlgoConfigSetAttribute(
Bdesc
,
// &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
beta
,
// cublasLtMatmulAlgoConfigSetAttribute(
C
,
// &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
Cdesc
,
// cublasLtMatmulAlgoConfigSetAttribute(
C
,
// &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
Cdesc
,
// cublasLtMatmulAlgoConfigSetAttribute(&algo,
(
findAlgo
==
1
?
(
&
algo
)
:
NULL
),
// CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
workSpace
,
// &(info.reductionScheme),
workspaceSize
,
// sizeof(info.reductionScheme));
stream_
);
// // #if (CUDART_VERSION >= 11000)
cublasLtMatmulDescDestroy
(
operationDesc
);
// // cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatrixLayoutDestroy
(
Adesc
);
// // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
cublasLtMatrixLayoutDestroy
(
Bdesc
);
// // #endif
cublasLtMatrixLayoutDestroy
(
Cdesc
);
sync_check_cuda_error
();
// #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
}
// cublasLtMatmulAlgoConfigSetAttribute(
else
{
// &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(&algo,
// CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
// &(info.cluster_shapeId),
// sizeof(info.cluster_shapeId));
// #elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
// #endif
// }
// }
// // cublasLtMatmul(cublaslt_handle_,
// // operationDesc,
// // alpha,
// // A,
// // Adesc,
// // B,
// // Bdesc,
// // beta,
// // C,
// // Cdesc,
// // C,
// // Cdesc,
// // (findAlgo == 1 ? (&algo) : NULL),
// // workSpace,
// // workspaceSize,
// // stream_);
// cublasLtMatmulDescDestroy(operationDesc);
// cublasLtMatrixLayoutDestroy(Adesc);
// cublasLtMatrixLayoutDestroy(Bdesc);
// cublasLtMatrixLayoutDestroy(Cdesc);
// sync_check_cuda_error();
// }
// else {
int
cublasAlgo
=
info
.
algoId
;
int
cublasAlgo
=
info
.
algoId
;
check_cuda_error
(
cublasGemmEx
(
cublas_handle_
,
check_cuda_error
(
cublasGemmEx
(
cublas_handle_
,
transa
,
transa
,
...
@@ -324,7 +326,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
...
@@ -324,7 +326,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
computeType_
,
computeType_
,
static_cast
<
cublasGemmAlgo_t
>
(
cublasAlgo
)));
static_cast
<
cublasGemmAlgo_t
>
(
cublasAlgo
)));
sync_check_cuda_error
();
sync_check_cuda_error
();
}
//
}
mu_
->
unlock
();
mu_
->
unlock
();
}
}
...
@@ -341,7 +343,7 @@ void cublasMMWrapper::setFP16GemmConfig()
...
@@ -341,7 +343,7 @@ void cublasMMWrapper::setFP16GemmConfig()
Atype_
=
CUDA_R_16F
;
Atype_
=
CUDA_R_16F
;
Btype_
=
CUDA_R_16F
;
Btype_
=
CUDA_R_16F
;
Ctype_
=
CUDA_R_16F
;
Ctype_
=
CUDA_R_16F
;
computeType_
=
CUDA_R_
32
F
;
computeType_
=
CUDA_R_
16
F
;
}
}
#ifdef ENABLE_BF16
#ifdef ENABLE_BF16
...
@@ -381,81 +383,81 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
...
@@ -381,81 +383,81 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
return
FLOAT_DATATYPE
;
return
FLOAT_DATATYPE
;
}
}
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
// input, weight, output are row-major
//
// input, weight, output are row-major
// only works for cublas 11.x
//
// only works for cublas 11.x
void
cublasMMWrapper
::
Gemm
(
cublasOperation_t
transa
,
//
void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasOperation_t
transb
,
//
cublasOperation_t transb,
const
int
m
,
//
const int m,
const
int
n
,
//
const int n,
const
int
k
,
//
const int k,
const
void
*
A
,
//
const void* A,
const
int
lda
,
//
const int lda,
const
void
*
B
,
//
const void* B,
const
int
ldb
,
//
const int ldb,
const
void
*
bias
,
//
const void* bias,
void
*
C
,
//
void* C,
const
int
ldc
)
//
const int ldc)
{
//
{
TM_LOG_DEBUG
(
__PRETTY_FUNCTION__
);
//
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
cudaDataType_t
Atype
,
Btype
,
Ctype
;
//
cudaDataType_t Atype, Btype, Ctype;
cublasComputeType_t
computeType
;
//
cublasComputeType_t computeType;
cudaDataType_t
scaleType
;
//
cudaDataType_t scaleType;
float
alpha_float
=
1.0
f
;
//
float alpha_float = 1.0f;
float
beta_float
=
0.0
f
;
//
float beta_float = 0.0f;
half
alpha_half
=
half
(
1.0
f
);
//
half alpha_half = half(1.0f);
half
beta_half
=
half
(
0.0
f
);
//
half beta_half = half(0.0f);
void
*
alpha
,
*
beta
;
//
void * alpha, *beta;
// int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
//
// int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
if
(
Atype_
==
CUDA_R_32F
)
{
//
if (Atype_ == CUDA_R_32F) {
computeType
=
CUBLAS_COMPUTE_32F_FAST_TF32
;
//
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype
=
CUDA_R_32F
;
//
Atype = CUDA_R_32F;
Btype
=
CUDA_R_32F
;
//
Btype = CUDA_R_32F;
Ctype
=
CUDA_R_32F
;
//
Ctype = CUDA_R_32F;
scaleType
=
CUDA_R_32F
;
//
scaleType = CUDA_R_32F;
alpha
=
&
alpha_float
;
//
alpha = &alpha_float;
beta
=
&
beta_float
;
//
beta = &beta_float;
}
//
}
else
if
(
Atype_
==
CUDA_R_16BF
)
{
//
else if (Atype_ == CUDA_R_16BF) {
computeType
=
CUBLAS_COMPUTE_32F_FAST_TF32
;
//
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype
=
CUDA_R_16BF
;
//
Atype = CUDA_R_16BF;
Btype
=
CUDA_R_16BF
;
//
Btype = CUDA_R_16BF;
Ctype
=
CUDA_R_16BF
;
//
Ctype = CUDA_R_16BF;
scaleType
=
CUDA_R_32F
;
//
scaleType = CUDA_R_32F;
alpha
=
&
alpha_float
;
//
alpha = &alpha_float;
beta
=
&
beta_float
;
//
beta = &beta_float;
}
//
}
else
{
//
else {
computeType
=
CUBLAS_COMPUTE_16F
;
//
computeType = CUBLAS_COMPUTE_16F;
Atype
=
CUDA_R_16F
;
//
Atype = CUDA_R_16F;
Btype
=
CUDA_R_16F
;
//
Btype = CUDA_R_16F;
Ctype
=
CUDA_R_16F
;
//
Ctype = CUDA_R_16F;
scaleType
=
CUDA_R_16F
;
//
scaleType = CUDA_R_16F;
alpha
=
&
alpha_half
;
//
alpha = &alpha_half;
beta
=
&
beta_half
;
//
beta = &beta_half;
}
//
}
cublasLtMatmulDesc_t
operationDesc
=
NULL
;
//
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t
Adesc
=
NULL
,
Bdesc
=
NULL
,
Cdesc
=
NULL
;
//
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasLtEpilogue_t
epi
=
CUBLASLT_EPILOGUE_BIAS
;
//
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
cublasLtMatrixLayoutCreate
(
&
Adesc
,
Atype
,
(
transa
==
CUBLAS_OP_N
)
?
m
:
k
,
(
transa
==
CUBLAS_OP_N
)
?
k
:
m
,
lda
);
//
cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
cublasLtMatrixLayoutCreate
(
&
Bdesc
,
Btype
,
(
transb
==
CUBLAS_OP_N
)
?
k
:
n
,
(
transb
==
CUBLAS_OP_N
)
?
n
:
k
,
ldb
);
//
cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb);
cublasLtMatrixLayoutCreate
(
&
Cdesc
,
Ctype
,
m
,
n
,
ldc
);
//
cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
cublasLtMatmulDescCreate
(
&
operationDesc
,
computeType
,
scaleType
);
//
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
transa
,
sizeof
(
cublasOperation_t
));
//
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
transb
,
sizeof
(
cublasOperation_t
));
//
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_EPILOGUE
,
&
epi
,
sizeof
(
cublasLtEpilogue_t
));
//
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_BIAS_POINTER
,
&
bias
,
sizeof
(
const
void
*
));
//
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
check_cuda_error
(
cublasLtMatmul
(
//
//
check_cuda_error(cublasLtMatmul(
cublaslt_handle_
,
operationDesc
,
alpha
,
A
,
Adesc
,
B
,
Bdesc
,
beta
,
C
,
Cdesc
,
C
,
Cdesc
,
NULL
,
NULL
,
0
,
stream_
));
//
//
cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
cublasLtMatrixLayoutDestroy
(
Adesc
);
//
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy
(
Bdesc
);
//
cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy
(
Cdesc
);
//
cublasLtMatrixLayoutDestroy(Cdesc);
cublasLtMatmulDescDestroy
(
operationDesc
);
//
cublasLtMatmulDescDestroy(operationDesc);
}
//
}
#endif
//
#endif
void
cublasMMWrapper
::
setStream
(
cudaStream_t
stream
)
void
cublasMMWrapper
::
setStream
(
cudaStream_t
stream
)
{
{
stream_
=
stream
;
stream_
=
stream
;
...
@@ -985,7 +987,8 @@ void cublasMMWrapper::_Int8Gemm(const int m,
...
@@ -985,7 +987,8 @@ void cublasMMWrapper::_Int8Gemm(const int m,
* - 0: int8 * int8 -> int32 -> int8
* - 0: int8 * int8 -> int32 -> int8
* - 1: int8 * int8 -> int32 -> int32
* - 1: int8 * int8 -> int32 -> int32
*/
*/
#if (CUBLAS_VERSION) <= 11601
// #if (CUBLAS_VERSION) <= 11601
#if 1
FT_CHECK_WITH_INFO
(
false
,
"CUBLAS version too low."
);
FT_CHECK_WITH_INFO
(
false
,
"CUBLAS version too low."
);
#else
#else
...
...
src/turbomind/utils/cublasMMWrapper.h
View file @
9484fd1c
...
@@ -207,20 +207,20 @@ public:
...
@@ -207,20 +207,20 @@ public:
CublasDataType
getCublasDataType
(
cudaDataType_t
data_type
);
CublasDataType
getCublasDataType
(
cudaDataType_t
data_type
);
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
void
Gemm
(
cublasOperation_t
transa
,
//
void Gemm(cublasOperation_t transa,
cublasOperation_t
transb
,
//
cublasOperation_t transb,
const
int
m
,
//
const int m,
const
int
n
,
//
const int n,
const
int
k
,
//
const int k,
const
void
*
A
,
//
const void* A,
const
int
lda
,
//
const int lda,
const
void
*
B
,
//
const void* B,
const
int
ldb
,
//
const int ldb,
const
void
*
bias
,
//
const void* bias,
void
*
C
,
//
void* C,
const
int
ldc
);
//
const int ldc);
#endif
//
#endif
void
stridedBatchedGemm
(
cublasOperation_t
transa
,
void
stridedBatchedGemm
(
cublasOperation_t
transa
,
cublasOperation_t
transb
,
cublasOperation_t
transb
,
...
...
src/turbomind/utils/cuda_type_utils.cuh
View file @
9484fd1c
...
@@ -322,7 +322,7 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val)
...
@@ -322,7 +322,7 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val)
int16_t
int16_in
;
int16_t
int16_in
;
};
};
fp16
=
val
;
fp16
=
val
;
asm
volatile
(
"cvt.rni.sat.s8.f16 %0, %1;"
:
"=h"
(
int16
)
:
"h"
(
int16_in
));
//
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return
int8
[
0
];
return
int8
[
0
];
}
}
...
@@ -333,20 +333,31 @@ __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
...
@@ -333,20 +333,31 @@ __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
int8_t
int8
[
2
];
int8_t
int8
[
2
];
int16_t
int16
;
int16_t
int16
;
};
};
int8
[
0
]
=
cuda_cast
<
int8_t
>
(
val
.
x
);
// int8[0] = cuda_cast<int8_t>(val.x);
int8
[
1
]
=
cuda_cast
<
int8_t
>
(
val
.
y
);
// int8[1] = cuda_cast<int8_t>(val.y);
int8
[
0
]
=
cuda_cast
<
int8_t
>
((
val
.
data
[
0
]));
int8
[
1
]
=
cuda_cast
<
int8_t
>
((
val
.
data
[
1
]));
return
int16
;
return
int16
;
}
}
template
<
>
template
<
>
__device__
inline
int8_t
cuda_cast
<
int8_t
,
float
>
(
float
val
)
__device__
inline
int8_t
cuda_cast
<
int8_t
,
float
>
(
float
val
)
{
{
union
{
// union {
int8_t
int8
[
2
];
// int8_t int8[2];
int16_t
int16
;
// int16_t int16;
};
// };
asm
volatile
(
"cvt.rni.sat.s8.f32 %0, %1;"
:
"=h"
(
int16
)
:
"f"
(
val
));
// asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return
int8
[
0
];
// return int8[0];
int8_t
dst
;
if
(
val
>=
128
){
dst
=
127
;
}
else
if
(
val
<
-
128
){
dst
=
-
128
;
}
else
{
dst
=
static_cast
<
int8_t
>
(
val
);
}
return
dst
;
}
}
template
<
>
template
<
>
...
@@ -528,7 +539,8 @@ __device__ inline To cuda_max(Ti val)
...
@@ -528,7 +539,8 @@ __device__ inline To cuda_max(Ti val)
template
<
>
template
<
>
__device__
inline
half
cuda_max
(
half2
val
)
__device__
inline
half
cuda_max
(
half2
val
)
{
{
return
(
val
.
x
>
val
.
y
)
?
val
.
x
:
val
.
y
;
// return (val.x > val.y) ? val.x : val.y;
return
(
val
.
data
[
0
]
>
val
.
data
[
1
])
?
val
.
data
[
0
]
:
val
.
data
[
1
];
}
}
#ifdef ENABLE_BF16
#ifdef ENABLE_BF16
template
<
>
template
<
>
...
...
src/turbomind/utils/custom_ar_comm.cc
View file @
9484fd1c
...
@@ -152,17 +152,17 @@ void initCustomAllReduceComm(std::vector<std::shared_ptr<AbstractCustomComm>>* c
...
@@ -152,17 +152,17 @@ void initCustomAllReduceComm(std::vector<std::shared_ptr<AbstractCustomComm>>* c
return
;
return
;
}
}
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
//
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
for
(
size_t
i
=
0
;
i
<
rank_size
;
i
++
)
{
//
for (size_t i = 0; i < rank_size; i++) {
custom_all_reduce_comms
->
push_back
(
std
::
make_shared
<
CustomAllReduceComm
<
T
>>
(
rank_size
,
i
));
//
custom_all_reduce_comms->push_back(std::make_shared<CustomAllReduceComm<T>>(rank_size, i));
}
//
}
custom_all_reduce_comms
->
at
(
0
)
->
allocateAndExchangePeerAccessPointer
(
custom_all_reduce_comms
);
//
custom_all_reduce_comms->at(0)->allocateAndExchangePeerAccessPointer(custom_all_reduce_comms);
#else
//
#else
TM_LOG_WARNING
(
"Custom All Reduce is not supported before CUDA 11.2. Using NCCL as Comm."
);
TM_LOG_WARNING
(
"Custom All Reduce is not supported before CUDA 11.2. Using NCCL as Comm."
);
for
(
size_t
i
=
0
;
i
<
rank_size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
rank_size
;
i
++
)
{
custom_all_reduce_comms
->
push_back
(
nullptr
);
custom_all_reduce_comms
->
push_back
(
nullptr
);
}
}
#endif
//
#endif
}
}
// Template instantiation
// Template instantiation
...
...
src/turbomind/utils/gemm.cc
View file @
9484fd1c
...
@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file)
...
@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file)
stream_
=
stream
;
stream_
=
stream
;
mutex_
=
new
std
::
mutex
();
// mutex per process
mutex_
=
new
std
::
mutex
();
// mutex per process
check_cuda_error
(
cublasCreate
(
&
cublas_handle_
));
check_cuda_error
(
cublasCreate
(
&
cublas_handle_
));
check_cuda_error
(
cublasLtCreate
(
&
cublaslt_handle_
));
//
check_cuda_error(cublasLtCreate(&cublaslt_handle_));
check_cuda_error
(
cublasSetStream
(
cublas_handle_
,
stream
));
check_cuda_error
(
cublasSetStream
(
cublas_handle_
,
stream
));
if
(
allocator_
!=
nullptr
)
{
if
(
allocator_
!=
nullptr
)
{
...
@@ -41,7 +41,7 @@ Gemm::~Gemm()
...
@@ -41,7 +41,7 @@ Gemm::~Gemm()
allocator_
->
free
((
void
**
)(
&
workspace_
));
allocator_
->
free
((
void
**
)(
&
workspace_
));
allocator_
=
nullptr
;
allocator_
=
nullptr
;
}
}
cublasLtDestroy
(
cublaslt_handle_
);
//
cublasLtDestroy(cublaslt_handle_);
cublasDestroy
(
cublas_handle_
);
cublasDestroy
(
cublas_handle_
);
delete
cublas_algo_map_
;
delete
cublas_algo_map_
;
delete
mutex_
;
delete
mutex_
;
...
@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa,
...
@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa,
mutex_
->
lock
();
mutex_
->
lock
();
// Use cublas as default in FP32 and cublasLt as default in FP16
// Use cublas as default in FP32 and cublasLt as default in FP16
bool
is_fp16_compute_type
=
compute_type_
==
TYPE_FP16
;
bool
is_fp16_compute_type
=
compute_type_
==
TYPE_FP16
;
bool
using_cublasLt
=
Atype
==
TYPE_FP16
;
// bool using_cublasLt = Atype == TYPE_FP16;
bool
using_cublasLt
=
(
Atype
==
TYPE_FP16
)
?
false
:
false
;
int
batch_count
=
1
;
int
batch_count
=
1
;
half
h_alpha
=
(
half
)
alpha
;
half
h_alpha
=
(
half
)
alpha
;
...
@@ -267,82 +268,83 @@ void Gemm::gemm(const GemmOp transa,
...
@@ -267,82 +268,83 @@ void Gemm::gemm(const GemmOp transa,
using_cublasLt
=
(
info
.
stages
!=
-
1
);
using_cublasLt
=
(
info
.
stages
!=
-
1
);
}
}
if
(
using_cublasLt
)
{
// if (using_cublasLt) {
const
size_t
a_rows
=
(
a_op
==
getCublasOperation
(
GEMM_OP_N
))
?
_m
:
k
;
// if(0) {
const
size_t
a_cols
=
(
a_op
==
getCublasOperation
(
GEMM_OP_N
))
?
k
:
_m
;
// const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k;
const
size_t
b_rows
=
(
b_op
==
getCublasOperation
(
GEMM_OP_N
))
?
k
:
_n
;
// const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m;
const
size_t
b_cols
=
(
b_op
==
getCublasOperation
(
GEMM_OP_N
))
?
_n
:
k
;
// const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n;
// const size_t b_cols = (b_op == getCublasOperation(GEMM_OP_N)) ? _n : k;
cublasLtMatmulDesc_t
matmul_desc
=
NULL
;
cublasLtMatrixLayout_t
a_desc
=
NULL
,
b_desc
=
NULL
,
c_desc
=
NULL
;
// cublasLtMatmulDesc_t matmul_desc = NULL;
cudaDataType_t
scale_type
=
getCublasDataType
(
compute_type_
);
// cublasLtMatrixLayout_t a_desc = NULL, b_desc = NULL, c_desc = NULL;
auto
compute_type
=
getCublasComputeType
(
compute_type_
);
// cudaDataType_t scale_type = getCublasDataType(compute_type_);
// auto compute_type = getCublasComputeType(compute_type_);
// --------------------------------------
// Create descriptors for the original matrices
// // --------------------------------------
cublasLtMatrixLayoutCreate
(
&
a_desc
,
a_type
,
a_rows
,
a_cols
,
_lda
);
// // Create descriptors for the original matrices
cublasLtMatrixLayoutCreate
(
&
b_desc
,
b_type
,
b_rows
,
b_cols
,
_ldb
);
// cublasLtMatrixLayoutCreate(&a_desc, a_type, a_rows, a_cols, _lda);
cublasLtMatrixLayoutCreate
(
&
c_desc
,
c_type
,
_m
,
_n
,
ldc
);
// cublasLtMatrixLayoutCreate(&b_desc, b_type, b_rows, b_cols, _ldb);
#if (CUDART_VERSION >= 11000)
// cublasLtMatrixLayoutCreate(&c_desc, c_type, _m, _n, ldc);
cublasLtMatmulDescCreate
(
&
matmul_desc
,
compute_type
,
scale_type
);
// #if (CUDART_VERSION >= 11000)
#else
// cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_type);
cublasLtMatmulDescCreate
(
&
matmul_desc
,
compute_type
);
// #else
#endif
// cublasLtMatmulDescCreate(&matmul_desc, compute_type);
// #endif
cublasLtMatmulDescSetAttribute
(
matmul_desc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
a_op
,
sizeof
(
cublasOperation_t
));
cublasLtMatmulDescSetAttribute
(
matmul_desc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
b_op
,
sizeof
(
cublasOperation_t
));
// cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &a_op, sizeof(cublasOperation_t));
// cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &b_op, sizeof(cublasOperation_t));
cublasLtMatmulAlgo_t
algo
;
void
*
workspace
=
workspace_
;
// cublasLtMatmulAlgo_t algo;
int
workspace_size
=
workspace_
==
nullptr
?
0
:
CUBLAS_WORKSPACE_SIZE
;
// void* workspace = workspace_;
if
(
findAlgo
)
{
// int workspace_size = workspace_ == nullptr ? 0 : CUBLAS_WORKSPACE_SIZE;
if
(
info
.
workspaceSize
>
workspace_size
)
{
// if (findAlgo) {
findAlgo
=
0
;
// if (info.workspaceSize > workspace_size) {
}
// findAlgo = 0;
else
{
// }
cublasLtMatmulAlgoInit
(
// else {
cublaslt_handle_
,
compute_type
,
scale_type
,
a_type
,
b_type
,
c_type
,
c_type
,
info
.
algoId
,
&
algo
);
// cublasLtMatmulAlgoInit(
cublasLtMatmulAlgoConfigSetAttribute
(
// cublaslt_handle_, compute_type, scale_type, a_type, b_type, c_type, c_type, info.algoId, &algo);
&
algo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
(
info
.
customOption
),
sizeof
(
info
.
customOption
));
// cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgoConfigSetAttribute
(
// &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
&
algo
,
CUBLASLT_ALGO_CONFIG_TILE_ID
,
&
(
info
.
tile
),
sizeof
(
info
.
tile
));
// cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgoConfigSetAttribute
(
// &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
&
algo
,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
&
(
info
.
splitK_val
),
sizeof
(
info
.
splitK_val
));
// cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgoConfigSetAttribute
(
// &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
&
algo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
(
info
.
swizzle
),
sizeof
(
info
.
swizzle
));
// cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgoConfigSetAttribute
(
// &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
&
algo
,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
&
(
info
.
reductionScheme
),
sizeof
(
int
));
// cublasLtMatmulAlgoConfigSetAttribute(
#if (CUDART_VERSION >= 11000)
// &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int));
cublasLtMatmulAlgoConfigSetAttribute
(
// #if (CUDART_VERSION >= 11000)
&
algo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
(
info
.
stages
),
sizeof
(
info
.
stages
));
// cublasLtMatmulAlgoConfigSetAttribute(
#endif
// &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
}
// #endif
}
// }
// }
cublasLtMatmul
(
cublaslt_handle_
,
matmul_desc
,
// cublasLtMatmul(cublaslt_handle_,
alpha_ptr
,
// matmul_desc,
a_data_ptr
,
// alpha_ptr,
a_desc
,
// a_data_ptr,
b_data_ptr
,
// a_desc,
b_desc
,
// b_data_ptr,
beta_ptr
,
// b_desc,
C
,
// beta_ptr,
c_desc
,
// C,
C
,
// c_desc,
c_desc
,
// C,
(
findAlgo
==
1
?
(
&
algo
)
:
NULL
),
// c_desc,
workspace
,
// (findAlgo == 1 ? (&algo) : NULL),
workspace_size
,
// workspace,
stream_
);
// workspace_size,
// stream_);
cublasLtMatmulDescDestroy
(
matmul_desc
);
cublasLtMatrixLayoutDestroy
(
a_desc
);
// cublasLtMatmulDescDestroy(matmul_desc);
cublasLtMatrixLayoutDestroy
(
b_desc
);
// cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatrixLayoutDestroy
(
c_desc
);
// cublasLtMatrixLayoutDestroy(b_desc);
sync_check_cuda_error
();
// cublasLtMatrixLayoutDestroy(c_desc);
}
// sync_check_cuda_error();
else
{
// }
// else {
cudaDataType_t
compute_type
=
getCublasDataType
(
compute_type_
);
cudaDataType_t
compute_type
=
getCublasDataType
(
compute_type_
);
int
cublas_algo
=
info
.
algoId
;
int
cublas_algo
=
info
.
algoId
;
check_cuda_error
(
cublasGemmEx
(
cublas_handle_
,
check_cuda_error
(
cublasGemmEx
(
cublas_handle_
,
...
@@ -365,7 +367,7 @@ void Gemm::gemm(const GemmOp transa,
...
@@ -365,7 +367,7 @@ void Gemm::gemm(const GemmOp transa,
compute_type
,
compute_type
,
static_cast
<
cublasGemmAlgo_t
>
(
cublas_algo
)));
static_cast
<
cublasGemmAlgo_t
>
(
cublas_algo
)));
sync_check_cuda_error
();
sync_check_cuda_error
();
}
//
}
mutex_
->
unlock
();
mutex_
->
unlock
();
}
}
...
@@ -1033,19 +1035,19 @@ cudaDataType_t getCublasDataType(DataType dtype)
...
@@ -1033,19 +1035,19 @@ cudaDataType_t getCublasDataType(DataType dtype)
}
}
}
}
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
cublasComputeType_t
getCublasComputeType
(
DataType
ctype
)
//
cublasComputeType_t getCublasComputeType(DataType ctype)
{
//
{
switch
(
ctype
)
{
//
switch (ctype) {
case
TYPE_FP16
:
//
case TYPE_FP16:
return
CUBLAS_COMPUTE_16F
;
//
return CUBLAS_COMPUTE_16F;
case
TYPE_FP32
:
//
case TYPE_FP32:
return
CUBLAS_COMPUTE_32F
;
//
return CUBLAS_COMPUTE_32F;
default:
//
default:
throw
GemmNotSupportedException
(
"Not supported cublas compute type."
);
//
throw GemmNotSupportedException("Not supported cublas compute type.");
}
//
}
}
//
}
#else
//
#else
cudaDataType_t
getCublasComputeType
(
DataType
ctype
)
cudaDataType_t
getCublasComputeType
(
DataType
ctype
)
{
{
switch
(
ctype
)
{
switch
(
ctype
)
{
...
@@ -1057,7 +1059,7 @@ cudaDataType_t getCublasComputeType(DataType ctype)
...
@@ -1057,7 +1059,7 @@ cudaDataType_t getCublasComputeType(DataType ctype)
throw
GemmNotSupportedException
(
"Not supported cublas compute type."
);
throw
GemmNotSupportedException
(
"Not supported cublas compute type."
);
}
}
}
}
#endif
//
#endif
cublasOperation_t
getCublasOperation
(
GemmOp
op
)
cublasOperation_t
getCublasOperation
(
GemmOp
op
)
{
{
...
...
src/turbomind/utils/gemm.h
View file @
9484fd1c
...
@@ -622,11 +622,11 @@ std::shared_ptr<Gemm>
...
@@ -622,11 +622,11 @@ std::shared_ptr<Gemm>
createGemm
(
IAllocator
*
allocator
,
cudaStream_t
stream
,
bool
sparse
=
false
,
bool
quantized
=
false
);
createGemm
(
IAllocator
*
allocator
,
cudaStream_t
stream
,
bool
sparse
=
false
,
bool
quantized
=
false
);
cudaDataType_t
getCublasDataType
(
DataType
dtype
);
cudaDataType_t
getCublasDataType
(
DataType
dtype
);
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
cublasComputeType_t
getCublasComputeType
(
DataType
dtype
);
//
cublasComputeType_t getCublasComputeType(DataType dtype);
#else
//
#else
cudaDataType_t
getCublasComputeType
(
DataType
dtype
);
cudaDataType_t
getCublasComputeType
(
DataType
dtype
);
#endif
//
#endif
cublasOperation_t
getCublasOperation
(
GemmOp
op
);
cublasOperation_t
getCublasOperation
(
GemmOp
op
);
std
::
string
getGemmOpString
(
const
GemmOp
&
op
);
std
::
string
getGemmOpString
(
const
GemmOp
&
op
);
...
...
src/turbomind/utils/gemm_test/CMakeLists.txt
View file @
9484fd1c
...
@@ -13,7 +13,8 @@
...
@@ -13,7 +13,8 @@
# limitations under the License.
# limitations under the License.
cmake_minimum_required
(
VERSION 3.8
)
cmake_minimum_required
(
VERSION 3.8
)
find_package
(
CUDAToolkit REQUIRED
)
#find_package(CUDAToolkit REQUIRED)
find_package
(
CUDA REQUIRED
)
set
(
gemm_func_files
set
(
gemm_func_files
gemm_func.cc
gemm_func.cc
...
@@ -51,59 +52,71 @@ set(swin_gemm_func_files
...
@@ -51,59 +52,71 @@ set(swin_gemm_func_files
swin_gemm_func.cc
swin_gemm_func.cc
)
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-fPIC"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-fPIC"
)
add_library
(
gemm_func STATIC
${
gemm_func_files
}
)
add_library
(
gemm_func STATIC
${
gemm_func_files
}
)
target_link_libraries
(
gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart cuda_utils logger
)
#target_link_libraries(gemm_func PUBLIC cublas cublasLt cudart cuda_utils logger)
set_property
(
TARGET gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
target_link_libraries
(
gemm_func PUBLIC cublas cudart cuda_utils logger
)
set_property
(
TARGET gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#set_property(TARGET gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library
(
encoder_gemm_func STATIC
${
encoder_gemm_func_files
}
)
add_library
(
encoder_gemm_func STATIC
${
encoder_gemm_func_files
}
)
target_link_libraries
(
encoder_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger
)
#target_link_libraries(encoder_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries
(
encoder_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger
)
if
(
SPARSITY_SUPPORT
)
if
(
SPARSITY_SUPPORT
)
target_link_libraries
(
encoder_gemm_func PUBLIC
CUDA::
cusparse -lcusparseLt
)
target_link_libraries
(
encoder_gemm_func PUBLIC cusparse -lcusparseLt
)
endif
()
endif
()
set_property
(
TARGET encoder_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
#
set_property(TARGET encoder_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property
(
TARGET encoder_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#
set_property(TARGET encoder_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library
(
encoder_igemm_func STATIC
${
encoder_igemm_func_files
}
)
add_library
(
encoder_igemm_func STATIC
${
encoder_igemm_func_files
}
)
target_link_libraries
(
encoder_igemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart cuda_utils logger
)
#target_link_libraries(encoder_igemm_func PUBLIC cublas cublasLt cudart cuda_utils logger)
target_link_libraries
(
encoder_igemm_func PUBLIC cublas cudart cuda_utils logger
)
if
(
SPARSITY_SUPPORT
)
if
(
SPARSITY_SUPPORT
)
target_link_libraries
(
encoder_igemm_func PUBLIC
CUDA::
cusparse -lcusparseLt
)
target_link_libraries
(
encoder_igemm_func PUBLIC cusparse -lcusparseLt
)
endif
()
endif
()
set_property
(
TARGET encoder_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
#
set_property(TARGET encoder_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property
(
TARGET encoder_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#
set_property(TARGET encoder_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library
(
decoding_gemm_func STATIC
${
decoding_gemm_func_files
}
)
add_library
(
decoding_gemm_func STATIC
${
decoding_gemm_func_files
}
)
target_link_libraries
(
decoding_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger
)
#target_link_libraries(decoding_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property
(
TARGET decoding_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
target_link_libraries
(
decoding_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger
)
set_property
(
TARGET decoding_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#set_property(TARGET decoding_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET decoding_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library
(
gpt_gemm_func STATIC
${
gpt_gemm_func_files
}
)
add_library
(
gpt_gemm_func STATIC
${
gpt_gemm_func_files
}
)
target_link_libraries
(
gpt_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger
)
#target_link_libraries(gpt_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries
(
gpt_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger
)
if
(
SPARSITY_SUPPORT
)
if
(
SPARSITY_SUPPORT
)
target_link_libraries
(
gpt_gemm_func PUBLIC
CUDA::
cusparse -lcusparseLt
)
target_link_libraries
(
gpt_gemm_func PUBLIC cusparse -lcusparseLt
)
endif
()
endif
()
set_property
(
TARGET gpt_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
#
set_property(TARGET gpt_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property
(
TARGET gpt_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#
set_property(TARGET gpt_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library
(
xlnet_gemm_func STATIC
${
xlnet_gemm_func_files
}
)
add_library
(
xlnet_gemm_func STATIC
${
xlnet_gemm_func_files
}
)
target_link_libraries
(
xlnet_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger
)
#target_link_libraries(xlnet_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property
(
TARGET xlnet_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
target_link_libraries
(
xlnet_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger
)
set_property
(
TARGET xlnet_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#set_property(TARGET xlnet_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET xlnet_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library
(
t5_gemm_func STATIC
${
t5_gemm_func_files
}
)
add_library
(
t5_gemm_func STATIC
${
t5_gemm_func_files
}
)
target_link_libraries
(
t5_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger
)
#target_link_libraries(t5_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries
(
t5_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger
)
if
(
SPARSITY_SUPPORT
)
if
(
SPARSITY_SUPPORT
)
target_link_libraries
(
t5_gemm_func PUBLIC
CUDA::
cusparse -lcusparseLt
)
target_link_libraries
(
t5_gemm_func PUBLIC cusparse -lcusparseLt
)
endif
()
endif
()
set_property
(
TARGET t5_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
#
set_property(TARGET t5_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property
(
TARGET t5_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#
set_property(TARGET t5_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library
(
swin_igemm_func STATIC
${
swin_igemm_func_files
}
)
add_library
(
swin_igemm_func STATIC
${
swin_igemm_func_files
}
)
target_link_libraries
(
swin_igemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func encoder_igemm_func cuda_utils logger
)
#target_link_libraries(swin_igemm_func PUBLIC cublas cublasLt cudart gemm_func encoder_igemm_func cuda_utils logger)
set_property
(
TARGET swin_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
target_link_libraries
(
swin_igemm_func PUBLIC cublas cudart gemm_func encoder_igemm_func cuda_utils logger
)
set_property
(
TARGET swin_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#set_property(TARGET swin_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET swin_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library
(
swin_gemm_func STATIC
${
swin_gemm_func_files
}
)
add_library
(
swin_gemm_func STATIC
${
swin_gemm_func_files
}
)
target_link_libraries
(
swin_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger
)
#target_link_libraries(swin_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property
(
TARGET swin_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON
)
target_link_libraries
(
swin_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger
)
set_property
(
TARGET swin_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
#set_property(TARGET swin_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET swin_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
src/turbomind/utils/gemm_test/decoding_gemm_func.cc
View file @
9484fd1c
...
@@ -130,8 +130,8 @@ void generate_decoding_gemm_config(int batch_size,
...
@@ -130,8 +130,8 @@ void generate_decoding_gemm_config(int batch_size,
cublasHandle_t
cublas_handle
;
cublasHandle_t
cublas_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
cublasLtHandle_t
ltHandle
;
//
cublasLtHandle_t ltHandle;
check_cuda_error
(
cublasLtCreate
(
&
ltHandle
));
//
check_cuda_error(cublasLtCreate(<Handle));
cudaDataType_t
AType
;
cudaDataType_t
AType
;
cudaDataType_t
BType
;
cudaDataType_t
BType
;
...
@@ -148,16 +148,19 @@ void generate_decoding_gemm_config(int batch_size,
...
@@ -148,16 +148,19 @@ void generate_decoding_gemm_config(int batch_size,
CType
=
CUDA_R_32F
;
CType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO23
;
// endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
data_type
=
HALF_DATATYPE
;
data_type
=
HALF_DATATYPE
;
AType
=
CUDA_R_16F
;
AType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_16F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#ifdef ENABLE_BF16
#ifdef ENABLE_BF16
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
...
@@ -166,11 +169,14 @@ void generate_decoding_gemm_config(int batch_size,
...
@@ -166,11 +169,14 @@ void generate_decoding_gemm_config(int batch_size,
BType
=
CUDA_R_16BF
;
BType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#endif
#endif
using
scaleT
=
typename
ScaleTypeConverter
<
T
>::
Type
;
// using scaleT = typename ScaleTypeConverter<T>::Type;
using
scaleT
=
typename
ScaleTypeConverter
<
T
,
true
>::
Type
;
scaleT
alpha
=
(
scaleT
)
1.0
f
;
scaleT
alpha
=
(
scaleT
)
1.0
f
;
scaleT
beta
=
(
scaleT
)
0.0
f
;
scaleT
beta
=
(
scaleT
)
0.0
f
;
...
@@ -241,38 +247,39 @@ void generate_decoding_gemm_config(int batch_size,
...
@@ -241,38 +247,39 @@ void generate_decoding_gemm_config(int batch_size,
const
int
ALGO_COMBINATIONS
=
5000
;
const
int
ALGO_COMBINATIONS
=
5000
;
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
LtHgemmCustomFind
<
T
,
scaleT
>
(
ltHandle
,
// LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size
*
beam_width
,
// batch_size * beam_width,
seq_len
,
// seq_len,
head_num
,
// head_num,
size_per_head
,
// size_per_head,
n
,
// n,
m
,
// m,
k
,
// k,
&
alpha
,
// &alpha,
d_B
,
// d_B,
d_A
,
// d_A,
&
beta
,
// &beta,
d_C
,
// d_C,
cublas_workspace
,
// cublas_workspace,
workSpaceSize
,
// workSpaceSize,
fd
,
// fd,
perfResults
,
// perfResults,
ALGO_COMBINATIONS
);
// ALGO_COMBINATIONS);
if
(
perfResults
[
0
].
time
<
exec_time
)
{
// if (perfResults[0].time < exec_time) {
printPerfStructure
(
batch_size
*
beam_width
,
// printPerfStructure(batch_size * beam_width,
seq_len
,
// seq_len,
head_num
,
// head_num,
size_per_head
,
// size_per_head,
n
,
// n,
m
,
// m,
k
,
// k,
perfResults
[
0
],
// perfResults[0],
fd
,
// fd,
data_type
,
// data_type,
0
);
// 0);
}
// }
else
{
// else {
{
fprintf
(
fd
,
fprintf
(
fd
,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
...
...
src/turbomind/utils/gemm_test/encoder_gemm_func.cc
View file @
9484fd1c
...
@@ -127,8 +127,8 @@ void generate_encoder_gemm_config(
...
@@ -127,8 +127,8 @@ void generate_encoder_gemm_config(
cublasHandle_t
cublas_handle
;
cublasHandle_t
cublas_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
cublasLtHandle_t
ltHandle
;
//
cublasLtHandle_t ltHandle;
check_cuda_error
(
cublasLtCreate
(
&
ltHandle
));
//
check_cuda_error(cublasLtCreate(<Handle));
cudaDataType_t
AType
;
cudaDataType_t
AType
;
cudaDataType_t
BType
;
cudaDataType_t
BType
;
...
@@ -145,16 +145,19 @@ void generate_encoder_gemm_config(
...
@@ -145,16 +145,19 @@ void generate_encoder_gemm_config(
CType
=
CUDA_R_32F
;
CType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO23
;
// endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
data_type
=
HALF_DATATYPE
;
data_type
=
HALF_DATATYPE
;
AType
=
CUDA_R_16F
;
AType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_16F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#ifdef ENABLE_BF16
#ifdef ENABLE_BF16
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
...
@@ -163,11 +166,14 @@ void generate_encoder_gemm_config(
...
@@ -163,11 +166,14 @@ void generate_encoder_gemm_config(
BType
=
CUDA_R_16BF
;
BType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#endif
#endif
using
scaleT
=
typename
ScaleTypeConverter
<
T
,
false
>::
Type
;
// using scaleT = typename ScaleTypeConverter<T, false>::Type;
using
scaleT
=
typename
ScaleTypeConverter
<
T
,
true
>::
Type
;
scaleT
alpha
=
(
scaleT
)
1.0
f
;
scaleT
alpha
=
(
scaleT
)
1.0
f
;
scaleT
beta
=
(
scaleT
)
0.0
f
;
scaleT
beta
=
(
scaleT
)
0.0
f
;
...
@@ -331,30 +337,31 @@ void generate_encoder_gemm_config(
...
@@ -331,30 +337,31 @@ void generate_encoder_gemm_config(
// Let try a fixed number of combinations
// Let try a fixed number of combinations
const
int
ALGO_COMBINATIONS
=
5000
;
const
int
ALGO_COMBINATIONS
=
5000
;
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
LtHgemmCustomFind
<
T
,
scaleT
>
(
ltHandle
,
// LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size
,
// batch_size,
seq_len
,
// seq_len,
head_num
,
// head_num,
size_per_head
,
// size_per_head,
n
,
// n,
m
,
// m,
k
,
// k,
&
alpha
,
// &alpha,
d_B
,
// d_B,
d_A
,
// d_A,
&
beta
,
// &beta,
d_C
,
// d_C,
cublas_workspace
,
// cublas_workspace,
workSpaceSize
,
// workSpaceSize,
fd
,
// fd,
perfResults
,
// perfResults,
ALGO_COMBINATIONS
);
// ALGO_COMBINATIONS);
if
(
perfResults
[
0
].
time
<
exec_time
)
{
// if (perfResults[0].time < exec_time) {
printPerfStructure
(
// printPerfStructure(
batch_size
,
seq_len
,
head_num
,
size_per_head
,
n
,
m
,
k
,
perfResults
[
0
],
fd
,
data_type
,
0
);
// batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time
=
perfResults
[
0
].
time
;
// exec_time = perfResults[0].time;
}
// }
else
{
// else {
{
fprintf
(
fd
,
fprintf
(
fd
,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
...
...
src/turbomind/utils/gemm_test/encoder_igemm_func.cc
View file @
9484fd1c
...
@@ -82,11 +82,11 @@ int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE
...
@@ -82,11 +82,11 @@ int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
swizzle
,
sizeof
(
swizzle
),
NULL
);
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
swizzle
,
sizeof
(
swizzle
),
NULL
);
cublasLtMatmulAlgoConfigGetAttribute
(
cublasLtMatmulAlgoConfigGetAttribute
(
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
),
NULL
);
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
),
NULL
);
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute
(
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
stages
,
sizeof
(
stages
),
NULL
);
//
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else
//
#else
stages
=
0
;
stages
=
0
;
#endif
//
#endif
printf
(
"algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
printf
(
"algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
"time %f workspace=%d mathMode=%d waves=%f
\n
"
,
"time %f workspace=%d mathMode=%d waves=%f
\n
"
,
...
@@ -148,11 +148,11 @@ int printBatchPerfStructure(
...
@@ -148,11 +148,11 @@ int printBatchPerfStructure(
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
swizzle
,
sizeof
(
swizzle
),
NULL
);
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
swizzle
,
sizeof
(
swizzle
),
NULL
);
cublasLtMatmulAlgoConfigGetAttribute
(
cublasLtMatmulAlgoConfigGetAttribute
(
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
),
NULL
);
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
),
NULL
);
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute
(
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
stages
,
sizeof
(
stages
),
NULL
);
//
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else
//
#else
stages
=
0
;
stages
=
0
;
#endif
//
#endif
printf
(
"algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
printf
(
"algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
"time %f workspace=%d mathMode=%d waves=%f
\n
"
,
"time %f workspace=%d mathMode=%d waves=%f
\n
"
,
...
@@ -234,22 +234,22 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, //
...
@@ -234,22 +234,22 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, //
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
auto
start
=
std
::
chrono
::
high_resolution_clock
::
now
();
auto
start
=
std
::
chrono
::
high_resolution_clock
::
now
();
for
(
int
loop
=
0
;
loop
<
repeats
;
loop
++
)
{
for
(
int
loop
=
0
;
loop
<
repeats
;
loop
++
)
{
oneRunStatus
=
cublasLtMatmul
(
ltHandle
,
//
oneRunStatus = cublasLtMatmul(ltHandle,
operationDesc
,
//
operationDesc,
alpha
,
//
alpha,
A
,
//
A,
Adesc
,
//
Adesc,
B
,
//
B,
Bdesc
,
//
Bdesc,
beta
,
//
beta,
C
,
//
C,
Cdesc
,
//
Cdesc,
D
,
//
D,
Ddesc
,
//
Ddesc,
&
algo
,
//
&algo,
workSpace
,
//
workSpace,
workSpaceSizeInBytes
,
//
workSpaceSizeInBytes,
stream
);
//
stream);
}
}
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
auto
end
=
std
::
chrono
::
high_resolution_clock
::
now
();
auto
end
=
std
::
chrono
::
high_resolution_clock
::
now
();
...
@@ -279,693 +279,693 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, //
...
@@ -279,693 +279,693 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, //
// Sample wrapper running through multiple algo and config attributes combination for INT8 gemm using cublasLt low-level
// Sample wrapper running through multiple algo and config attributes combination for INT8 gemm using cublasLt low-level
// API
// API
template
<
typename
T
,
typename
scaleT
>
//
template<typename T, typename scaleT>
int
LtIgemmCustomFind
(
cublasLtHandle_t
ltHandle
,
//
int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
int
m
,
//
int m,
int
n
,
//
int n,
int
k
,
//
int k,
const
scaleT
*
alpha
,
/* host pointer */
//
const scaleT* alpha, /* host pointer */
const
int8_t
*
A
,
//
const int8_t* A,
const
int8_t
*
B
,
//
const int8_t* B,
const
scaleT
*
beta
,
/* host pointer */
//
const scaleT* beta, /* host pointer */
T
*
C
,
//
T* C,
void
*
workSpace
,
//
void* workSpace,
size_t
workSpaceSize
,
//
size_t workSpaceSize,
FILE
*
fout
)
//
FILE* fout)
{
//
{
cublasStatus_t
status
=
CUBLAS_STATUS_SUCCESS
;
//
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDesc_t
operationDesc
=
NULL
;
//
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t
Adesc
=
NULL
,
Bdesc
=
NULL
,
Cdesc
=
NULL
;
//
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaStream_t
stream
=
0
;
//
cudaStream_t stream = 0;
// SplitK value that we are going to try when SplitK is supported for a given algo
//
// SplitK value that we are going to try when SplitK is supported for a given algo
const
int
splitKSequenceA
[]
=
{
2
,
3
,
4
,
5
,
6
,
8
,
12
,
16
,
32
};
//
const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32};
// Let try a fixed number of combinations
//
// Let try a fixed number of combinations
#define ALGO_COMBINATIONS 50000
//
#define ALGO_COMBINATIONS 50000
int
AlgoCombinations
=
ALGO_COMBINATIONS
;
//
int AlgoCombinations = ALGO_COMBINATIONS;
int
AlgoCount
=
0
;
//
int AlgoCount = 0;
int
kernelRepeats
=
100
;
// number of time the CUDA kernels will be run back to back
//
int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
//
customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
int
nbAlgoIds
=
0
;
//
int nbAlgoIds = 0;
#define ALGO_IDS 100
//
#define ALGO_IDS 100
int
algoIdA
[
ALGO_IDS
];
//
int algoIdA[ALGO_IDS];
cudaDataType_t
Atype
,
Btype
,
Ctype
,
scaleType
;
//
cudaDataType_t Atype, Btype, Ctype, scaleType;
Atype
=
CUDA_R_8I
;
//
Atype = CUDA_R_8I;
Btype
=
CUDA_R_8I
;
//
Btype = CUDA_R_8I;
if
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
std
::
is_same
<
scaleT
,
int
>::
value
)
{
//
if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) {
Ctype
=
CUDA_R_32I
;
//
Ctype = CUDA_R_32I;
scaleType
=
CUDA_R_32I
;
//
scaleType = CUDA_R_32I;
}
//
}
else
if
(
std
::
is_same
<
T
,
int8_t
>::
value
&&
std
::
is_same
<
scaleT
,
float
>::
value
)
{
//
else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) {
Ctype
=
CUDA_R_8I
;
//
Ctype = CUDA_R_8I;
scaleType
=
CUDA_R_32F
;
//
scaleType = CUDA_R_32F;
}
//
}
else
{
//
else {
printf
(
"[ERROR]<T,scaleT> of igemm is invalid
\n
"
);
//
printf("[ERROR]<T,scaleT> of igemm is invalid\n");
exit
(
-
1
);
//
exit(-1);
}
//
}
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
cublasComputeType_t
computeType
=
CUBLAS_COMPUTE_32I
;
// //
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else
// //
#else
cudaDataType_t
computeType
=
CUDA_R_32I
;
//
cudaDataType_t computeType = CUDA_R_32I;
#endif
// //
#endif
cublasOperation_t
opTranspose
=
CUBLAS_OP_T
;
//
cublasOperation_t opTranspose = CUBLAS_OP_T;
bool
use_ORDER_COL32_2R_4R4
=
false
;
//
bool use_ORDER_COL32_2R_4R4 = false;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
int
device
{
-
1
};
// //
int device{-1};
cudaGetDevice
(
&
device
);
// //
cudaGetDevice(&device);
cudaDeviceProp
props
;
// //
cudaDeviceProp props;
cudaGetDeviceProperties
(
&
props
,
device
);
// //
cudaGetDeviceProperties(&props, device);
if
(
props
.
major
*
10
+
props
.
minor
>=
80
)
{
// //
if (props.major * 10 + props.minor >= 80) {
use_ORDER_COL32_2R_4R4
=
true
;
// //
use_ORDER_COL32_2R_4R4 = true;
}
// //
}
#endif
// //
#endif
cublasLtOrder_t
order_COL32
=
CUBLASLT_ORDER_COL32
;
//
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t
order_matrixB
;
//
cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
if
(
use_ORDER_COL32_2R_4R4
)
{
// //
if (use_ORDER_COL32_2R_4R4) {
order_matrixB
=
CUBLASLT_ORDER_COL32_2R_4R4
;
// //
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
}
// //
}
else
{
// //
else {
order_matrixB
=
CUBLASLT_ORDER_COL4_4R2_8C
;
// //
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
}
// //
}
#else
// //
#else
order_matrixB
=
CUBLASLT_ORDER_COL4_4R2_8C
;
//
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif
// //
#endif
int
ldaTransform
=
32
*
m
;
//
int ldaTransform = 32 * m;
int
ldbTransform
;
//
int ldbTransform;
if
(
use_ORDER_COL32_2R_4R4
)
{
//
if (use_ORDER_COL32_2R_4R4) {
ldbTransform
=
32
*
((
n
+
32
-
1
)
/
32
)
*
32
;
//
ldbTransform = 32 * ((n + 32 - 1) / 32) * 32;
}
//
}
else
{
//
else {
ldbTransform
=
32
*
((
n
+
8
-
1
)
/
8
)
*
8
;
//
ldbTransform = 32 * ((n + 8 - 1) / 8) * 8;
}
//
}
int
ldcTransform
=
32
*
m
;
//
int ldcTransform = 32 * m;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
status
=
cublasLtMatmulDescCreate
(
&
operationDesc
,
computeType
,
scaleType
);
// //
status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else
// //
#else
status
=
cublasLtMatmulDescCreate
(
&
operationDesc
,
scaleType
);
//
status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
#endif
// //
#endif
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
opTranspose
,
sizeof
(
cublasOperation_t
));
//
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
// Create matrix descriptors.
//
// Create matrix descriptors.
status
=
cublasLtMatrixLayoutCreate
(
&
Adesc
,
Atype
,
m
,
k
,
ldaTransform
);
//
status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutSetAttribute
(
Adesc
,
CUBLASLT_MATRIX_LAYOUT_ORDER
,
&
order_COL32
,
sizeof
(
order_COL32
));
//
status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutCreate
(
&
Bdesc
,
Btype
,
n
,
k
,
ldbTransform
);
//
status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
//
status =
cublasLtMatrixLayoutSetAttribute
(
Bdesc
,
CUBLASLT_MATRIX_LAYOUT_ORDER
,
&
order_matrixB
,
sizeof
(
order_matrixB
));
//
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB));
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutCreate
(
&
Cdesc
,
Ctype
,
m
,
n
,
ldcTransform
);
//
status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutSetAttribute
(
Cdesc
,
CUBLASLT_MATRIX_LAYOUT_ORDER
,
&
order_COL32
,
sizeof
(
order_COL32
));
//
status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
// Request AlgoId available for IGEMM
//
// Request AlgoId available for IGEMM
status
=
cublasLtMatmulAlgoGetIds
(
//
status = cublasLtMatmulAlgoGetIds(
ltHandle
,
computeType
,
scaleType
,
Atype
,
Btype
,
Ctype
,
Ctype
,
ALGO_IDS
,
algoIdA
,
&
nbAlgoIds
);
//
ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
// Loop over the Algo IDs
//
// Loop over the Algo IDs
for
(
int
idx
=
0
;
(
idx
<
nbAlgoIds
)
&&
(
AlgoCount
<
AlgoCombinations
);
idx
++
)
{
//
for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) {
cublasLtMatmulAlgo_t
algo
;
//
cublasLtMatmulAlgo_t algo;
size_t
sizeWritten
=
0
;
//
size_t sizeWritten = 0;
/* Initialize algo structure with given Algp ID */
//
/* Initialize algo structure with given Algp ID */
status
=
//
status =
cublasLtMatmulAlgoInit
(
ltHandle
,
computeType
,
scaleType
,
Atype
,
Btype
,
Ctype
,
Ctype
,
algoIdA
[
idx
],
&
algo
);
//
cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
continue
;
//
continue;
}
//
}
// Query the tiles enums supported by that algo
//
// Query the tiles enums supported by that algo
cublasLtMatmulAlgoCapGetAttribute
(
&
algo
,
CUBLASLT_ALGO_CAP_TILE_IDS
,
NULL
,
0
,
&
sizeWritten
);
//
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten);
int
nbTiles
=
int
(
sizeWritten
/
sizeof
(
int
));
//
int nbTiles = int(sizeWritten / sizeof(int));
int
*
tileA
=
new
int
[
nbTiles
==
0
?
1
:
nbTiles
];
//
int* tileA = new int[nbTiles == 0 ? 1 : nbTiles];
if
(
nbTiles
==
0
)
{
//
if (nbTiles == 0) {
tileA
[
0
]
=
CUBLASLT_MATMUL_TILE_UNDEFINED
;
//
tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED;
nbTiles
=
1
;
//
nbTiles = 1;
}
//
}
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoCapGetAttribute
(
&
algo
,
CUBLASLT_ALGO_CAP_STAGES_IDS
,
NULL
,
0
,
&
sizeWritten
);
// //
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten);
int
nbStages
=
int
(
sizeWritten
/
sizeof
(
int
));
// //
int nbStages = int(sizeWritten / sizeof(int));
std
::
vector
<
int
>
stagesA
(
nbStages
==
0
?
1
:
nbStages
);
// //
std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages);
if
(
nbStages
==
0
)
{
// //
if (nbStages == 0) {
stagesA
[
0
]
=
CUBLASLT_MATMUL_STAGES_UNDEFINED
;
// //
stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
nbStages
=
1
;
// //
nbStages = 1;
}
// //
}
else
{
// //
else {
cublasLtMatmulAlgoCapGetAttribute
(
// //
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_STAGES_IDS
,
stagesA
.
data
(),
sizeof
(
int
)
*
nbStages
,
&
sizeWritten
);
// //
&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten);
}
// //
}
#endif
// //
#endif
int
splitkSupport
,
redMask
,
swizzlingMax
,
customOptionMax
;
//
int splitkSupport, redMask, swizzlingMax, customOptionMax;
// Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations
//
// Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_TILE_IDS
,
tileA
,
sizeof
(
int
)
*
nbTiles
,
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_SPLITK_SUPPORT
,
&
splitkSupport
,
sizeof
(
splitkSupport
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK
,
&
redMask
,
sizeof
(
redMask
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT
,
&
swizzlingMax
,
sizeof
(
swizzlingMax
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX
,
&
customOptionMax
,
sizeof
(
customOptionMax
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten);
/* Loop over the different tiles */
//
/* Loop over the different tiles */
for
(
int
tileIdx
=
0
;
tileIdx
<
nbTiles
;
tileIdx
++
)
{
//
for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) {
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
/* Loop over different stages count */
// //
/* Loop over different stages count */
for
(
int
stagesIdx
=
0
;
stagesIdx
<
nbStages
;
stagesIdx
++
)
{
// //
for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) {
cublasLtMatmulAlgoConfigSetAttribute
(
// //
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
stagesA
[
stagesIdx
],
sizeof
(
stagesA
[
stagesIdx
]));
// //
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]));
#endif
// //
#endif
/* Loop over the different custom option if any */
//
/* Loop over the different custom option if any */
for
(
int
customOption
=
0
;
customOption
<=
customOptionMax
;
customOption
++
)
{
//
for (int customOption = 0; customOption <= customOptionMax; customOption++) {
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
));
//
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption));
/* Loop over the CTAs swizzling support */
//
/* Loop over the CTAs swizzling support */
for
(
int
k
=
0
;
k
<=
swizzlingMax
;
k
++
)
{
//
for (int k = 0; k <= swizzlingMax; k++) {
int
splitK_trial
=
0
;
//
int splitK_trial = 0;
if
(
splitkSupport
)
{
//
if (splitkSupport) {
splitK_trial
+=
sizeof
(
splitKSequenceA
)
/
sizeof
(
splitKSequenceA
[
0
]);
//
splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
}
//
}
// Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case
//
// Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case
// where splitK is not enabled
//
// where splitK is not enabled
for
(
int
l
=
0
;
(
l
<
(
1
+
splitK_trial
))
&&
(
AlgoCount
<
AlgoCombinations
);
l
++
)
{
//
for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) {
/* Setup attribute of the algo to run */
//
/* Setup attribute of the algo to run */
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_TILE_ID
,
&
tileA
[
tileIdx
],
sizeof
(
tileA
[
tileIdx
]));
//
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]));
int
splitK_val
=
0
;
//
int splitK_val = 0;
int
redScheme
=
CUBLASLT_REDUCTION_SCHEME_NONE
;
//
int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
&
splitK_val
,
sizeof
(
splitK_val
));
//
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val));
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
k
,
sizeof
(
k
));
//
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k));
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
&
redScheme
,
sizeof
(
int
));
//
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int));
if
(
l
>
0
)
{
// Split-K case
//
if (l > 0) { // Split-K case
splitK_val
=
splitKSequenceA
[
l
-
1
];
//
splitK_val = splitKSequenceA[l - 1];
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
//
cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
//
CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&
splitKSequenceA
[
l
-
1
],
//
&splitKSequenceA[l - 1],
sizeof
(
splitKSequenceA
[
l
-
1
]));
//
sizeof(splitKSequenceA[l - 1]));
/* Going over all the reduction scheme */
//
/* Going over all the reduction scheme */
for
(
redScheme
=
1
;
//
for (redScheme = 1;
redScheme
<=
(
int
)
CUBLASLT_REDUCTION_SCHEME_MASK
&&
(
AlgoCount
<
AlgoCombinations
);
//
redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations);
redScheme
=
redScheme
<<
1
)
{
//
redScheme = redScheme << 1) {
if
(
redScheme
&
redMask
)
{
//
if (redScheme & redMask) {
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
//
cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
//
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&
redScheme
,
//
&redScheme,
sizeof
(
redScheme
));
//
sizeof(redScheme));
status
=
customMatmulRun
(
ltHandle
,
//
status = customMatmulRun(ltHandle,
operationDesc
,
//
operationDesc,
alpha
,
/* host or device pointer */
//
alpha, /* host or device pointer */
A
,
//
A,
Adesc
,
//
Adesc,
B
,
//
B,
Bdesc
,
//
Bdesc,
beta
,
/* host or device pointer */
//
beta, /* host or device pointer */
C
,
//
C,
Cdesc
,
//
Cdesc,
C
,
//
C,
Cdesc
,
//
Cdesc,
algo
,
//
algo,
kernelRepeats
,
//
kernelRepeats,
workSpace
,
//
workSpace,
workSpaceSize
,
//
workSpaceSize,
perfResults
[
AlgoCount
],
//
perfResults[AlgoCount],
stream
);
//
stream);
perfResults
[
AlgoCount
].
status
=
status
;
//
perfResults[AlgoCount].status = status;
if
(
status
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount
++
;
//
AlgoCount++;
}
//
}
}
// end if
//
} // end if
}
// end for
//
} // end for
}
//
}
else
{
// Non-splitK case
//
else { // Non-splitK case
/* if user preference is ok with workspace */
//
/* if user preference is ok with workspace */
if
(
AlgoCount
<
AlgoCombinations
)
{
//
if (AlgoCount < AlgoCombinations) {
status
=
customMatmulRun
(
ltHandle
,
//
status = customMatmulRun(ltHandle,
operationDesc
,
//
operationDesc,
alpha
,
/* host or device pointer */
//
alpha, /* host or device pointer */
A
,
//
A,
Adesc
,
//
Adesc,
B
,
//
B,
Bdesc
,
//
Bdesc,
beta
,
/* host or device pointer */
//
beta, /* host or device pointer */
C
,
//
C,
Cdesc
,
//
Cdesc,
C
,
//
C,
Cdesc
,
//
Cdesc,
algo
,
//
algo,
kernelRepeats
,
//
kernelRepeats,
workSpace
,
//
workSpace,
workSpaceSize
,
//
workSpaceSize,
perfResults
[
AlgoCount
],
//
perfResults[AlgoCount],
stream
);
//
stream);
perfResults
[
AlgoCount
].
status
=
status
;
//
perfResults[AlgoCount].status = status;
if
(
status
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount
++
;
//
AlgoCount++;
}
//
}
}
//
}
}
//
}
}
// end l
//
} // end l
}
// end k
//
} // end k
}
// end customOption
//
} // end customOption
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
}
// end stagesIdx
// //
} // end stagesIdx
#endif
// //
#endif
}
// end tileIdx
//
} // end tileIdx
delete
[]
tileA
;
//
delete[] tileA;
}
// end idx
//
} // end idx
// Sort the results per run duration
//
// Sort the results per run duration
std
::
sort
(
perfResults
,
perfResults
+
AlgoCount
,
time_compare
);
//
std::sort(perfResults, perfResults + AlgoCount, time_compare);
// Print timing and perf details
//
// Print timing and perf details
for
(
int
i
=
0
,
hasPrint
=
0
;
i
<
AlgoCount
;
i
++
)
{
//
for (int i = 0, hasPrint = 0; i < AlgoCount; i++) {
printf
(
"result %03d : "
,
i
);
//
printf("result %03d : ", i);
hasPrint
=
printPerfStructure
(
m
,
n
,
k
,
perfResults
[
i
],
fout
,
hasPrint
);
//
hasPrint = printPerfStructure(m, n, k, perfResults[i], fout, hasPrint);
}
//
}
CLEANUP:
//
CLEANUP:
// Descriptors are no longer needed as all GPU work was already enqueued
//
// Descriptors are no longer needed as all GPU work was already enqueued
if
(
Cdesc
)
{
//
if (Cdesc) {
cublasLtMatrixLayoutDestroy
(
Cdesc
);
//
cublasLtMatrixLayoutDestroy(Cdesc);
}
//
}
if
(
Bdesc
)
{
//
if (Bdesc) {
cublasLtMatrixLayoutDestroy
(
Bdesc
);
//
cublasLtMatrixLayoutDestroy(Bdesc);
}
//
}
if
(
Adesc
)
{
//
if (Adesc) {
cublasLtMatrixLayoutDestroy
(
Adesc
);
//
cublasLtMatrixLayoutDestroy(Adesc);
}
//
}
if
(
operationDesc
)
{
//
if (operationDesc) {
cublasLtMatmulDescDestroy
(
operationDesc
);
//
cublasLtMatmulDescDestroy(operationDesc);
}
//
}
return
status
==
CUBLAS_STATUS_SUCCESS
?
0
:
1
;
//
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
//
}
// Explicit instantiation of the cuBLASLt IGEMM algorithm search for the
// INT8-in / INT32-out configuration: A and B are int8_t, C accumulates in
// int32_t, and alpha/beta are host int pointers (matching CUDA_R_32I scaling).
// NOTE(review): the template definition appears earlier in this file; this
// instantiation makes the <int32_t C, int scale> variant available to other
// translation units — confirm against the corresponding header declaration.
template int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
                               int              m,
                               int              n,
                               int              k,
                               const int*       alpha, /* host pointer */
                               const int8_t*    A,
                               const int8_t*    B,
                               const int*       beta, /* host pointer */
                               int32_t*         C,
                               void*            workSpace,
                               size_t           workSpaceSize,
                               FILE*            fout);
// Explicit instantiation of the cuBLASLt IGEMM algorithm search for the
// INT8-in / INT8-out configuration: A, B, and C are int8_t, with float
// alpha/beta host pointers (matching CUDA_R_32F scaling for requantized
// int8 output).
// NOTE(review): the template definition appears earlier in this file; this
// instantiation makes the <int8_t C, float scale> variant available to other
// translation units — confirm against the corresponding header declaration.
template int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
                               int              m,
                               int              n,
                               int              k,
                               const float*     alpha, /* host pointer */
                               const int8_t*    A,
                               const int8_t*    B,
                               const float*     beta, /* host pointer */
                               int8_t*          C,
                               void*            workSpace,
                               size_t           workSpaceSize,
                               FILE*            fout);
template
<
typename
T
,
typename
scaleT
>
//
template<typename T, typename scaleT>
int
LtBatchIgemmCustomFind
(
cublasLtHandle_t
ltHandle
,
//
int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
int
batchCount
,
//
int batchCount,
int
m
,
//
int m,
int
n
,
//
int n,
int
k
,
//
int k,
const
scaleT
*
alpha
,
/* host pointer */
//
const scaleT* alpha, /* host pointer */
const
int8_t
*
A
,
//
const int8_t* A,
const
int8_t
*
B
,
//
const int8_t* B,
const
scaleT
*
beta
,
/* host pointer */
//
const scaleT* beta, /* host pointer */
T
*
C
,
//
T* C,
void
*
workSpace
,
//
void* workSpace,
size_t
workSpaceSize
,
//
size_t workSpaceSize,
FILE
*
fout
)
//
FILE* fout)
{
//
{
cublasStatus_t
status
=
CUBLAS_STATUS_SUCCESS
;
//
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDesc_t
operationDesc
=
NULL
;
//
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t
Adesc
=
NULL
,
Bdesc
=
NULL
,
Cdesc
=
NULL
;
//
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaStream_t
stream
=
0
;
//
cudaStream_t stream = 0;
// SplitK value that we are going to try when SplitK is supported for a given algo
//
// SplitK value that we are going to try when SplitK is supported for a given algo
const
int
splitKSequenceA
[]
=
{
2
,
3
,
4
,
5
,
6
,
8
,
12
,
16
,
32
};
//
const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32};
// Let try a fixed number of combinations
//
// Let try a fixed number of combinations
#define ALGO_COMBINATIONS 50000
//
#define ALGO_COMBINATIONS 50000
int
AlgoCombinations
=
ALGO_COMBINATIONS
;
//
int AlgoCombinations = ALGO_COMBINATIONS;
int
AlgoCount
=
0
;
//
int AlgoCount = 0;
int
kernelRepeats
=
100
;
// number of time the CUDA kernels will be run back to back
//
int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
//
customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
int
nbAlgoIds
=
0
;
//
int nbAlgoIds = 0;
#define ALGO_IDS 100
//
#define ALGO_IDS 100
int
algoIdA
[
ALGO_IDS
];
//
int algoIdA[ALGO_IDS];
cudaDataType_t
Atype
,
Btype
,
Ctype
,
scaleType
;
//
cudaDataType_t Atype, Btype, Ctype, scaleType;
Atype
=
CUDA_R_8I
;
//
Atype = CUDA_R_8I;
Btype
=
CUDA_R_8I
;
//
Btype = CUDA_R_8I;
if
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
std
::
is_same
<
scaleT
,
int
>::
value
)
{
//
if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) {
Ctype
=
CUDA_R_32I
;
//
Ctype = CUDA_R_32I;
scaleType
=
CUDA_R_32I
;
//
scaleType = CUDA_R_32I;
}
//
}
else
if
(
std
::
is_same
<
T
,
int8_t
>::
value
&&
std
::
is_same
<
scaleT
,
float
>::
value
)
{
//
else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) {
Ctype
=
CUDA_R_8I
;
//
Ctype = CUDA_R_8I;
scaleType
=
CUDA_R_32F
;
//
scaleType = CUDA_R_32F;
}
//
}
else
{
//
else {
printf
(
"[ERROR]<T,scaleT> of igemm is invalid
\n
"
);
//
printf("[ERROR]<T,scaleT> of igemm is invalid\n");
exit
(
-
1
);
//
exit(-1);
}
//
}
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
cublasComputeType_t
computeType
=
CUBLAS_COMPUTE_32I
;
// //
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else
// //
#else
cudaDataType_t
computeType
=
CUDA_R_32I
;
//
cudaDataType_t computeType = CUDA_R_32I;
#endif
// //
#endif
cublasOperation_t
opTranspose
=
CUBLAS_OP_T
;
//
cublasOperation_t opTranspose = CUBLAS_OP_T;
bool
use_ORDER_COL32_2R_4R4
=
false
;
//
bool use_ORDER_COL32_2R_4R4 = false;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
int
device
{
-
1
};
// //
int device{-1};
cudaGetDevice
(
&
device
);
// //
cudaGetDevice(&device);
cudaDeviceProp
props
;
// //
cudaDeviceProp props;
cudaGetDeviceProperties
(
&
props
,
device
);
// //
cudaGetDeviceProperties(&props, device);
if
(
props
.
major
*
10
+
props
.
minor
>=
80
)
{
// //
if (props.major * 10 + props.minor >= 80) {
use_ORDER_COL32_2R_4R4
=
true
;
// //
use_ORDER_COL32_2R_4R4 = true;
}
// //
}
#endif
// //
#endif
cublasLtOrder_t
order_COL32
=
CUBLASLT_ORDER_COL32
;
//
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t
order_matrixB
;
//
cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
if
(
use_ORDER_COL32_2R_4R4
)
{
// //
if (use_ORDER_COL32_2R_4R4) {
order_matrixB
=
CUBLASLT_ORDER_COL32_2R_4R4
;
// //
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
}
// //
}
else
{
// //
else {
order_matrixB
=
CUBLASLT_ORDER_COL4_4R2_8C
;
// //
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
}
// //
}
#else
// //
#else
order_matrixB
=
CUBLASLT_ORDER_COL4_4R2_8C
;
//
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif
// //
#endif
int
ldaTransform
=
32
*
m
;
//
int ldaTransform = 32 * m;
int
ldbTransform
;
//
int ldbTransform;
if
(
use_ORDER_COL32_2R_4R4
)
{
//
if (use_ORDER_COL32_2R_4R4) {
ldbTransform
=
32
*
((
n
+
32
-
1
)
/
32
)
*
32
;
//
ldbTransform = 32 * ((n + 32 - 1) / 32) * 32;
}
//
}
else
{
//
else {
ldbTransform
=
32
*
((
n
+
8
-
1
)
/
8
)
*
8
;
//
ldbTransform = 32 * ((n + 8 - 1) / 8) * 8;
}
//
}
int
ldcTransform
=
32
*
m
;
//
int ldcTransform = 32 * m;
int64_t
stridea
,
strideb
,
stridec
;
//
int64_t stridea, strideb, stridec;
stridea
=
m
*
k
;
//
stridea = m * k;
strideb
=
n
*
k
;
//
strideb = n * k;
stridec
=
m
*
n
;
//
stridec = m * n;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
status
=
cublasLtMatmulDescCreate
(
&
operationDesc
,
computeType
,
scaleType
);
// //
status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else
// //
#else
status
=
cublasLtMatmulDescCreate
(
&
operationDesc
,
scaleType
);
//
status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
#endif
// //
#endif
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
opTranspose
,
sizeof
(
cublasOperation_t
));
//
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
// Create matrix descriptors.
//
// Create matrix descriptors.
status
=
cublasLtMatrixLayoutCreate
(
&
Adesc
,
Atype
,
m
,
k
,
ldaTransform
);
//
status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutSetAttribute
(
Adesc
,
CUBLASLT_MATRIX_LAYOUT_ORDER
,
&
order_COL32
,
sizeof
(
order_COL32
));
//
status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
cublasLtMatrixLayoutSetAttribute
(
Adesc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batchCount
,
sizeof
(
batchCount
));
//
cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute
(
Adesc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
stridea
,
sizeof
(
stridea
));
//
cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, sizeof(stridea));
status
=
cublasLtMatrixLayoutCreate
(
&
Bdesc
,
Btype
,
n
,
k
,
ldbTransform
);
//
status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
//
status =
cublasLtMatrixLayoutSetAttribute
(
Bdesc
,
CUBLASLT_MATRIX_LAYOUT_ORDER
,
&
order_matrixB
,
sizeof
(
order_matrixB
));
//
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB));
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
cublasLtMatrixLayoutSetAttribute
(
Bdesc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batchCount
,
sizeof
(
batchCount
));
//
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute
(
Bdesc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
strideb
,
sizeof
(
strideb
));
//
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, sizeof(strideb));
status
=
cublasLtMatrixLayoutCreate
(
&
Cdesc
,
Ctype
,
m
,
n
,
ldcTransform
);
//
status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutSetAttribute
(
Cdesc
,
CUBLASLT_MATRIX_LAYOUT_ORDER
,
&
order_COL32
,
sizeof
(
order_COL32
));
//
status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
cublasLtMatrixLayoutSetAttribute
(
Cdesc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batchCount
,
sizeof
(
batchCount
));
//
cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute
(
Cdesc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
stridec
,
sizeof
(
stridec
));
//
cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec));
// Request AlgoId available for IGEMM
//
// Request AlgoId available for IGEMM
status
=
cublasLtMatmulAlgoGetIds
(
//
status = cublasLtMatmulAlgoGetIds(
ltHandle
,
computeType
,
scaleType
,
Atype
,
Btype
,
Ctype
,
Ctype
,
ALGO_IDS
,
algoIdA
,
&
nbAlgoIds
);
//
ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
// Loop over the Algo IDs
//
// Loop over the Algo IDs
for
(
int
idx
=
0
;
(
idx
<
nbAlgoIds
)
&&
(
AlgoCount
<
AlgoCombinations
);
idx
++
)
{
//
for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) {
cublasLtMatmulAlgo_t
algo
;
//
cublasLtMatmulAlgo_t algo;
size_t
sizeWritten
=
0
;
//
size_t sizeWritten = 0;
/* Initialize algo structure with given Algp ID */
//
/* Initialize algo structure with given Algp ID */
status
=
//
status =
cublasLtMatmulAlgoInit
(
ltHandle
,
computeType
,
scaleType
,
Atype
,
Btype
,
Ctype
,
Ctype
,
algoIdA
[
idx
],
&
algo
);
//
cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
continue
;
//
continue;
}
//
}
// Query the tiles enums supported by that algo
//
// Query the tiles enums supported by that algo
cublasLtMatmulAlgoCapGetAttribute
(
&
algo
,
CUBLASLT_ALGO_CAP_TILE_IDS
,
NULL
,
0
,
&
sizeWritten
);
//
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten);
int
nbTiles
=
int
(
sizeWritten
/
sizeof
(
int
));
//
int nbTiles = int(sizeWritten / sizeof(int));
int
*
tileA
=
new
int
[
nbTiles
==
0
?
1
:
nbTiles
];
//
int* tileA = new int[nbTiles == 0 ? 1 : nbTiles];
if
(
nbTiles
==
0
)
{
//
if (nbTiles == 0) {
tileA
[
0
]
=
CUBLASLT_MATMUL_TILE_UNDEFINED
;
//
tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED;
nbTiles
=
1
;
//
nbTiles = 1;
}
//
}
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoCapGetAttribute
(
&
algo
,
CUBLASLT_ALGO_CAP_STAGES_IDS
,
NULL
,
0
,
&
sizeWritten
);
// //
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten);
int
nbStages
=
int
(
sizeWritten
/
sizeof
(
int
));
// //
int nbStages = int(sizeWritten / sizeof(int));
std
::
vector
<
int
>
stagesA
(
nbStages
==
0
?
1
:
nbStages
);
// //
std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages);
if
(
nbStages
==
0
)
{
// //
if (nbStages == 0) {
stagesA
[
0
]
=
CUBLASLT_MATMUL_STAGES_UNDEFINED
;
// //
stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
nbStages
=
1
;
// //
nbStages = 1;
}
// //
}
else
{
// //
else {
cublasLtMatmulAlgoCapGetAttribute
(
// //
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_STAGES_IDS
,
stagesA
.
data
(),
sizeof
(
int
)
*
nbStages
,
&
sizeWritten
);
// //
&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten);
}
// //
}
#endif
// //
#endif
int
splitkSupport
,
redMask
,
swizzlingMax
,
customOptionMax
;
//
int splitkSupport, redMask, swizzlingMax, customOptionMax;
// Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations
//
// Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_TILE_IDS
,
tileA
,
sizeof
(
int
)
*
nbTiles
,
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_SPLITK_SUPPORT
,
&
splitkSupport
,
sizeof
(
splitkSupport
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK
,
&
redMask
,
sizeof
(
redMask
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT
,
&
swizzlingMax
,
sizeof
(
swizzlingMax
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX
,
&
customOptionMax
,
sizeof
(
customOptionMax
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten);
/* Loop over the different tiles */
//
/* Loop over the different tiles */
for
(
int
tileIdx
=
0
;
tileIdx
<
nbTiles
;
tileIdx
++
)
{
//
for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) {
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
/* Loop over different stages count */
// //
/* Loop over different stages count */
for
(
int
stagesIdx
=
0
;
stagesIdx
<
nbStages
;
stagesIdx
++
)
{
// //
for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) {
cublasLtMatmulAlgoConfigSetAttribute
(
// //
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
stagesA
[
stagesIdx
],
sizeof
(
stagesA
[
stagesIdx
]));
// //
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]));
#endif
// //
#endif
/* Loop over the different custom option if any */
//
/* Loop over the different custom option if any */
for
(
int
customOption
=
0
;
customOption
<=
customOptionMax
;
customOption
++
)
{
//
for (int customOption = 0; customOption <= customOptionMax; customOption++) {
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
));
//
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption));
/* Loop over the CTAs swizzling support */
//
/* Loop over the CTAs swizzling support */
for
(
int
k
=
0
;
k
<=
swizzlingMax
;
k
++
)
{
//
for (int k = 0; k <= swizzlingMax; k++) {
int
splitK_trial
=
0
;
//
int splitK_trial = 0;
if
(
splitkSupport
)
{
//
if (splitkSupport) {
splitK_trial
+=
sizeof
(
splitKSequenceA
)
/
sizeof
(
splitKSequenceA
[
0
]);
//
splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
}
//
}
// Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case
//
// Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case
// where splitK is not enabled
//
// where splitK is not enabled
for
(
int
l
=
0
;
(
l
<
(
1
+
splitK_trial
))
&&
(
AlgoCount
<
AlgoCombinations
);
l
++
)
{
//
for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) {
/* Setup attribute of the algo to run */
//
/* Setup attribute of the algo to run */
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_TILE_ID
,
&
tileA
[
tileIdx
],
sizeof
(
tileA
[
tileIdx
]));
//
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]));
int
splitK_val
=
0
;
//
int splitK_val = 0;
int
redScheme
=
CUBLASLT_REDUCTION_SCHEME_NONE
;
//
int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
&
splitK_val
,
sizeof
(
splitK_val
));
//
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val));
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
k
,
sizeof
(
k
));
//
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k));
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
&
redScheme
,
sizeof
(
int
));
//
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int));
if
(
l
>
0
)
{
// Split-K case
//
if (l > 0) { // Split-K case
splitK_val
=
splitKSequenceA
[
l
-
1
];
//
splitK_val = splitKSequenceA[l - 1];
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
//
cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
//
CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&
splitKSequenceA
[
l
-
1
],
//
&splitKSequenceA[l - 1],
sizeof
(
splitKSequenceA
[
l
-
1
]));
//
sizeof(splitKSequenceA[l - 1]));
/* Going over all the reduction scheme */
//
/* Going over all the reduction scheme */
for
(
redScheme
=
1
;
//
for (redScheme = 1;
redScheme
<=
(
int
)
CUBLASLT_REDUCTION_SCHEME_MASK
&&
(
AlgoCount
<
AlgoCombinations
);
//
redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations);
redScheme
=
redScheme
<<
1
)
{
//
redScheme = redScheme << 1) {
if
(
redScheme
&
redMask
)
{
//
if (redScheme & redMask) {
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
//
cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
//
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&
redScheme
,
//
&redScheme,
sizeof
(
redScheme
));
//
sizeof(redScheme));
status
=
customMatmulRun
(
ltHandle
,
//
status = customMatmulRun(ltHandle,
operationDesc
,
//
operationDesc,
alpha
,
/* host or device pointer */
//
alpha, /* host or device pointer */
A
,
//
A,
Adesc
,
//
Adesc,
B
,
//
B,
Bdesc
,
//
Bdesc,
beta
,
/* host or device pointer */
//
beta, /* host or device pointer */
C
,
//
C,
Cdesc
,
//
Cdesc,
C
,
//
C,
Cdesc
,
//
Cdesc,
algo
,
//
algo,
kernelRepeats
,
//
kernelRepeats,
workSpace
,
//
workSpace,
workSpaceSize
,
//
workSpaceSize,
perfResults
[
AlgoCount
],
//
perfResults[AlgoCount],
stream
);
//
stream);
perfResults
[
AlgoCount
].
status
=
status
;
//
perfResults[AlgoCount].status = status;
if
(
status
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount
++
;
//
AlgoCount++;
}
//
}
}
// end if
//
} // end if
}
// end for
//
} // end for
}
//
}
else
{
// Non-splitK case
//
else { // Non-splitK case
/* if user preference is ok with workspace */
//
/* if user preference is ok with workspace */
if
(
AlgoCount
<
AlgoCombinations
)
{
//
if (AlgoCount < AlgoCombinations) {
status
=
customMatmulRun
(
ltHandle
,
//
status = customMatmulRun(ltHandle,
operationDesc
,
//
operationDesc,
alpha
,
/* host or device pointer */
//
alpha, /* host or device pointer */
A
,
//
A,
Adesc
,
//
Adesc,
B
,
//
B,
Bdesc
,
//
Bdesc,
beta
,
/* host or device pointer */
//
beta, /* host or device pointer */
C
,
//
C,
Cdesc
,
//
Cdesc,
C
,
//
C,
Cdesc
,
//
Cdesc,
algo
,
//
algo,
kernelRepeats
,
//
kernelRepeats,
workSpace
,
//
workSpace,
workSpaceSize
,
//
workSpaceSize,
perfResults
[
AlgoCount
],
//
perfResults[AlgoCount],
stream
);
//
stream);
perfResults
[
AlgoCount
].
status
=
status
;
//
perfResults[AlgoCount].status = status;
if
(
status
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount
++
;
//
AlgoCount++;
}
//
}
}
//
}
}
//
}
}
// end l
//
} // end l
}
// end k
//
} // end k
}
// end customOption
//
} // end customOption
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
}
// end stagesIdx
// //
} // end stagesIdx
#endif
// //
#endif
}
// end tileIdx
//
} // end tileIdx
delete
[]
tileA
;
//
delete[] tileA;
}
// end idx
//
} // end idx
// Sort the results per run duration
//
// Sort the results per run duration
std
::
sort
(
perfResults
,
perfResults
+
AlgoCount
,
time_compare
);
//
std::sort(perfResults, perfResults + AlgoCount, time_compare);
// Print timing and perf details
//
// Print timing and perf details
for
(
int
i
=
0
,
hasPrint
=
0
;
i
<
AlgoCount
;
i
++
)
{
//
for (int i = 0, hasPrint = 0; i < AlgoCount; i++) {
printf
(
"result %03d : "
,
i
);
//
printf("result %03d : ", i);
hasPrint
=
printBatchPerfStructure
(
batchCount
,
m
,
n
,
k
,
perfResults
[
i
],
fout
,
hasPrint
);
//
hasPrint = printBatchPerfStructure(batchCount, m, n, k, perfResults[i], fout, hasPrint);
}
//
}
CLEANUP:
//
CLEANUP:
// Descriptors are no longer needed as all GPU work was already enqueued
//
// Descriptors are no longer needed as all GPU work was already enqueued
if
(
Cdesc
)
{
//
if (Cdesc) {
cublasLtMatrixLayoutDestroy
(
Cdesc
);
//
cublasLtMatrixLayoutDestroy(Cdesc);
}
//
}
if
(
Bdesc
)
{
//
if (Bdesc) {
cublasLtMatrixLayoutDestroy
(
Bdesc
);
//
cublasLtMatrixLayoutDestroy(Bdesc);
}
//
}
if
(
Adesc
)
{
//
if (Adesc) {
cublasLtMatrixLayoutDestroy
(
Adesc
);
//
cublasLtMatrixLayoutDestroy(Adesc);
}
//
}
if
(
operationDesc
)
{
//
if (operationDesc) {
cublasLtMatmulDescDestroy
(
operationDesc
);
//
cublasLtMatmulDescDestroy(operationDesc);
}
//
}
return
status
==
CUBLAS_STATUS_SUCCESS
?
0
:
1
;
//
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
//
}
// Explicit instantiation: int32 accumulation variant — int8 A/B inputs,
// int32_t C output, with host-side integer alpha/beta scaling factors.
template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
                                    int batchCount,
                                    int m,
                                    int n,
                                    int k,
                                    const int* alpha, /* host pointer */
                                    const int8_t* A,
                                    const int8_t* B,
                                    const int* beta, /* host pointer */
                                    int32_t* C,
                                    void* workSpace,
                                    size_t workSpaceSize,
                                    FILE* fout);
// Explicit instantiation: requantized variant — int8 A/B inputs, int8_t C
// output, with host-side float alpha/beta scaling factors.
template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
                                    int batchCount,
                                    int m,
                                    int n,
                                    int k,
                                    const float* alpha, /* host pointer */
                                    const int8_t* A,
                                    const int8_t* B,
                                    const float* beta, /* host pointer */
                                    int8_t* C,
                                    void* workSpace,
                                    size_t workSpaceSize,
                                    FILE* fout);
// initialize matrix in column-major
// initialize matrix in column-major
void
matInit
(
int
rows
,
int
cols
,
int8_t
*
p
,
int
ld
)
void
matInit
(
int
rows
,
int
cols
,
int8_t
*
p
,
int
ld
)
...
...
src/turbomind/utils/gemm_test/gemm_func.cc
View file @
9484fd1c
...
@@ -52,11 +52,11 @@ int printPerfStructure(int batch_size,
...
@@ -52,11 +52,11 @@ int printPerfStructure(int batch_size,
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
swizzle
,
sizeof
(
swizzle
),
NULL
);
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
swizzle
,
sizeof
(
swizzle
),
NULL
);
cublasLtMatmulAlgoConfigGetAttribute
(
cublasLtMatmulAlgoConfigGetAttribute
(
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
),
NULL
);
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
),
NULL
);
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute
(
matmulAlgo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
stages
,
sizeof
(
stages
),
NULL
);
//
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else
//
#else
stages
=
0
;
stages
=
0
;
#endif
//
#endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
uint16_t
inner_shapeId
,
cluster_shapeId
;
uint16_t
inner_shapeId
,
cluster_shapeId
;
cublasLtMatmulAlgoConfigGetAttribute
(
cublasLtMatmulAlgoConfigGetAttribute
(
...
@@ -74,9 +74,9 @@ int printPerfStructure(int batch_size,
...
@@ -74,9 +74,9 @@ int printPerfStructure(int batch_size,
#endif
#endif
printf
(
"algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d "
printf
(
"algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d "
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
"stages=%d "
//
"stages=%d "
#endif
//
#endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
"inner_shapeId=%d cluster_shapeId=%d"
"inner_shapeId=%d cluster_shapeId=%d"
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
...
@@ -91,9 +91,9 @@ int printPerfStructure(int batch_size,
...
@@ -91,9 +91,9 @@ int printPerfStructure(int batch_size,
reductionScheme
,
reductionScheme
,
swizzle
,
swizzle
,
customOption
,
customOption
,
#if (CUDART_VERSION >= 11000)
//
#if (CUDART_VERSION >= 11000)
stages
,
//
stages,
#endif
//
#endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
inner_shapeId
,
inner_shapeId
,
cluster_shapeId
,
cluster_shapeId
,
...
@@ -154,704 +154,704 @@ static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMa
...
@@ -154,704 +154,704 @@ static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMa
return
((
perf_a
.
status
==
CUBLAS_STATUS_SUCCESS
)
&&
(
perf_a
.
time
<
perf_b
.
time
));
return
((
perf_a
.
status
==
CUBLAS_STATUS_SUCCESS
)
&&
(
perf_a
.
time
<
perf_b
.
time
));
}
}
// Benchmarks one concrete cuBLASLt algo configuration.
//
// First validates the configuration with cublasLtMatmulAlgoCheck and rejects it
// when the heuristic's required workspace exceeds workSpaceSizeInBytes. Otherwise
// it launches `kernelRepeats` back-to-back cublasLtMatmul calls on `stream`,
// brackets them with the caller-provided CUDA events, and — only on full
// success — stores the algo, the average per-call time (ms), the workspace
// requirement and the waves count into perfResults.
static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle,  // to get the capabilities (required a GPU)
                                      cublasLtMatmulDesc_t operationDesc,
                                      const void* alpha, /* host or device pointer */
                                      const void* A,
                                      cublasLtMatrixLayout_t Adesc,
                                      const void* B,
                                      cublasLtMatrixLayout_t Bdesc,
                                      const void* beta, /* host or device pointer */
                                      const void* C,
                                      cublasLtMatrixLayout_t Cdesc,
                                      void* D,
                                      cublasLtMatrixLayout_t Ddesc,
                                      const cublasLtMatmulAlgo_t& algo,
                                      int kernelRepeats,
                                      void* workSpace,
                                      size_t workSpaceSizeInBytes,
                                      customMatmulPerf_t& perfResults,
                                      cudaStream_t stream,
                                      cudaEvent_t& startEvent,
                                      cudaEvent_t& stopEvent)
{
    // Ask cuBLASLt whether this algo/config is applicable to the problem at all.
    cublasLtMatmulHeuristicResult_t heuristic;
    cublasStatus_t                  status =
        cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, &algo, &heuristic);
    if (status != CUBLAS_STATUS_SUCCESS) {
        return status;
    }
    if (heuristic.workspaceSize > workSpaceSizeInBytes) {
        // The configuration needs more workspace than the caller provided.
        return CUBLAS_STATUS_NOT_SUPPORTED;
    }

    // Time kernelRepeats launches between the two events; all launches share `stream`.
    const cudaError_t recordStartErr = cudaEventRecord(startEvent, stream);
    for (int loop = 0; loop < kernelRepeats; ++loop) {
        const cublasStatus_t oneRunStatus = cublasLtMatmul(ltHandle,
                                                           operationDesc,
                                                           alpha,
                                                           A,
                                                           Adesc,
                                                           B,
                                                           Bdesc,
                                                           beta,
                                                           C,
                                                           Cdesc,
                                                           D,
                                                           Ddesc,
                                                           &algo,
                                                           workSpace,
                                                           workSpaceSizeInBytes,
                                                           stream);
        if (oneRunStatus != CUBLAS_STATUS_SUCCESS) {
            status = oneRunStatus;
            break;
        }
    }
    const cudaError_t recordStopErr = cudaEventRecord(stopEvent, stream);
    const cudaError_t syncErr       = cudaEventSynchronize(stopEvent);
    float             elapsedMs;
    const cudaError_t elapsedErr = cudaEventElapsedTime(&elapsedMs, startEvent, stopEvent);
    if (recordStartErr != cudaSuccess || recordStopErr != cudaSuccess || syncErr != cudaSuccess
        || elapsedErr != cudaSuccess) {
        status = CUBLAS_STATUS_INTERNAL_ERROR;
    }

    // For the moment only add successful findings (elapsedMs is only read here,
    // i.e. when every launch and every event call above succeeded).
    if (status == CUBLAS_STATUS_SUCCESS) {
        perfResults.algo          = algo;
        perfResults.time          = elapsedMs / kernelRepeats;  // average ms per matmul
        perfResults.workspaceSize = heuristic.workspaceSize;
        perfResults.wavesCount    = heuristic.wavesCount;
    }
    return status;
}
template
<
typename
T
,
typename
scaleT
>
//
template<typename T, typename scaleT>
int
LtHgemmCustomFind
(
cublasLtHandle_t
ltHandle
,
//
int LtHgemmCustomFind(cublasLtHandle_t ltHandle,
int
batch_size
,
//
int batch_size,
int
seq_len
,
//
int seq_len,
int
head_num
,
//
int head_num,
int
size_per_head
,
//
int size_per_head,
int
m
,
//
int m,
int
n
,
//
int n,
int
k
,
//
int k,
const
scaleT
*
alpha
,
/* host pointer */
//
const scaleT* alpha, /* host pointer */
const
T
*
A
,
//
const T* A,
const
T
*
B
,
//
const T* B,
const
scaleT
*
beta
,
/* host pointer */
//
const scaleT* beta, /* host pointer */
T
*
C
,
//
T* C,
void
*
workSpace
,
//
void* workSpace,
size_t
workSpaceSize
,
//
size_t workSpaceSize,
FILE
*
fout
,
//
FILE* fout,
customMatmulPerf_t
perfResults
[],
//
customMatmulPerf_t perfResults[],
int
AlgoCombinations
,
//
int AlgoCombinations,
cudaDataType_t
dtype_fp8
,
//
cudaDataType_t dtype_fp8,
int
batchCount
,
//
int batchCount,
int64_t
strideA
,
//
int64_t strideA,
int64_t
strideB
,
//
int64_t strideB,
int64_t
strideD
)
//
int64_t strideD)
{
//
{
cublasStatus_t
status
=
CUBLAS_STATUS_SUCCESS
;
//
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cudaEvent_t
startEvent
;
//
cudaEvent_t startEvent;
cudaEvent_t
stopEvent
;
//
cudaEvent_t stopEvent;
CublasDataType
data_type
;
//
CublasDataType data_type;
cublasLtMatmulDesc_t
operationDesc
=
NULL
;
//
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t
Adesc
=
NULL
,
Bdesc
=
NULL
,
Cdesc
=
NULL
,
Ddesc
=
NULL
;
//
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL;
cudaStream_t
stream
=
0
;
//
cudaStream_t stream = 0;
// SplitK value that we are going to try when SplitK is supported for a
//
// SplitK value that we are going to try when SplitK is supported for a
// given algo
//
// given algo
const
int
splitKSequenceA
[]
=
{
2
,
3
,
4
,
5
,
6
,
8
,
12
,
16
,
32
};
//
const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32};
// Let try a fixed number of combinations
//
// Let try a fixed number of combinations
int
AlgoCount
=
0
;
//
int AlgoCount = 0;
int
AlgoCountRestrict
=
0
;
// workspace == 0
//
int AlgoCountRestrict = 0; // workspace == 0
const
int
maxNumTraversal
=
50
;
// max number of traversal
//
const int maxNumTraversal = 50; // max number of traversal
std
::
vector
<
cublasLtMatmulAlgo_t
>
algos
(
AlgoCombinations
);
// 0 <= workspace <= 32MB
//
std::vector<cublasLtMatmulAlgo_t> algos(AlgoCombinations); // 0 <= workspace <= 32MB
std
::
vector
<
cublasLtMatmulAlgo_t
>
algosRestrict
(
AlgoCombinations
);
// workspace == 0
//
std::vector<cublasLtMatmulAlgo_t> algosRestrict(AlgoCombinations); // workspace == 0
const
int
kernelRepeats
=
100
;
// number of time the CUDA kernels will be run back to back
//
const int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back
int
nbAlgoIds
=
0
;
// Number of algorithms actually returned by
//
int nbAlgoIds = 0; // Number of algorithms actually returned by
// cublasLtMatmulAlgoGetIds function.
//
// cublasLtMatmulAlgoGetIds function.
#define ALGO_IDS 100 // Number of algorithms requested.
//
#define ALGO_IDS 100 // Number of algorithms requested.
int
algoIdA
[
ALGO_IDS
];
// Array containing the algorithm IDs returned by
//
int algoIdA[ALGO_IDS]; // Array containing the algorithm IDs returned by
// cublasLtMatmulAlgoGetIds function.
//
// cublasLtMatmulAlgoGetIds function.
cudaDataType_t
Atype
,
Btype
,
Ctype
,
scaleType
,
Dtype
;
//
cudaDataType_t Atype, Btype, Ctype, scaleType, Dtype;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
cublasComputeType_t
computeType
;
// //
cublasComputeType_t computeType;
#else
// //
#else
cudaDataType_t
computeType
;
//
cudaDataType_t computeType;
#endif
// //
#endif
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
//
if (std::is_same<T, float>::value) {
data_type
=
FLOAT_DATATYPE
;
//
data_type = FLOAT_DATATYPE;
Atype
=
CUDA_R_32F
,
Btype
=
CUDA_R_32F
,
Ctype
=
CUDA_R_32F
,
Dtype
=
CUDA_R_32F
;
//
Atype = CUDA_R_32F, Btype = CUDA_R_32F, Ctype = CUDA_R_32F, Dtype = CUDA_R_32F;
}
//
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
//
else if (std::is_same<T, half>::value) {
data_type
=
HALF_DATATYPE
;
//
data_type = HALF_DATATYPE;
Atype
=
CUDA_R_16F
,
Btype
=
CUDA_R_16F
,
Ctype
=
CUDA_R_16F
,
Dtype
=
CUDA_R_16F
;
//
Atype = CUDA_R_16F, Btype = CUDA_R_16F, Ctype = CUDA_R_16F, Dtype = CUDA_R_16F;
}
//
}
#ifdef ENABLE_BF16
//
#ifdef ENABLE_BF16
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
//
else if (std::is_same<T, __nv_bfloat16>::value) {
data_type
=
BFLOAT16_DATATYPE
;
//
data_type = BFLOAT16_DATATYPE;
Atype
=
CUDA_R_16BF
,
Btype
=
CUDA_R_16BF
,
Ctype
=
CUDA_R_16BF
,
Dtype
=
CUDA_R_16BF
;
//
Atype = CUDA_R_16BF, Btype = CUDA_R_16BF, Ctype = CUDA_R_16BF, Dtype = CUDA_R_16BF;
}
//
}
#endif
//
#endif
#ifdef ENABLE_FP8
//
#ifdef ENABLE_FP8
else
if
(
std
::
is_same
<
T
,
__nv_fp8_e4m3
>::
value
)
{
//
else if (std::is_same<T, __nv_fp8_e4m3>::value) {
data_type
=
FP8_DATATYPE
;
//
data_type = FP8_DATATYPE;
Atype
=
CUDA_R_8F_E4M3
,
Btype
=
CUDA_R_8F_E4M3
,
Ctype
=
CUDA_R_16BF
;
//
Atype = CUDA_R_8F_E4M3, Btype = CUDA_R_8F_E4M3, Ctype = CUDA_R_16BF;
#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE
//
#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE
Dtype
=
CUDA_R_16BF
;
//
Dtype = CUDA_R_16BF;
#else
//
#else
Dtype
=
dtype_fp8
;
//
Dtype = dtype_fp8;
#endif
//
#endif
}
//
}
#endif
//
#endif
if
(
sizeof
(
scaleT
)
==
sizeof
(
float
))
{
//
if (sizeof(scaleT) == sizeof(float)) {
scaleType
=
CUDA_R_32F
;
//
scaleType = CUDA_R_32F;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
computeType
=
CUBLAS_COMPUTE_32F
;
// //
computeType = CUBLAS_COMPUTE_32F;
#else
// //
#else
computeType
=
CUDA_R_32F
;
//
computeType = CUDA_R_32F;
#endif
// //
#endif
}
//
}
else
{
//
else {
scaleType
=
CUDA_R_16F
;
//
scaleType = CUDA_R_16F;
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
computeType
=
CUBLAS_COMPUTE_16F
;
// //
computeType = CUBLAS_COMPUTE_16F;
#else
// //
#else
computeType
=
CUDA_R_16F
;
//
computeType = CUDA_R_16F;
#endif
// //
#endif
}
//
}
const
cublasOperation_t
tA
=
data_type
==
FP8_DATATYPE
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
//
const cublasOperation_t tA = data_type == FP8_DATATYPE ? CUBLAS_OP_T : CUBLAS_OP_N;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t for
//
// Create operation descriptor; see cublasLtMatmulDescAttributes_t for
// details about defaults; here we just need to set the transforms for A and
//
// details about defaults; here we just need to set the transforms for A and
// B
//
// B
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
status
=
cublasLtMatmulDescCreate
(
&
operationDesc
,
computeType
,
// //
status = cublasLtMatmulDescCreate(&operationDesc, computeType,
scaleType
);
// creates a matrix multiply descriptor
// //
scaleType); // creates a matrix multiply descriptor
#else
// //
#else
status
=
cublasLtMatmulDescCreate
(
&
operationDesc
,
computeType
);
//
status = cublasLtMatmulDescCreate(&operationDesc, computeType);
#endif
// //
#endif
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
tA
,
sizeof
(
tA
));
//
status = cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA));
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
#ifdef ENABLE_FP8
//
#ifdef ENABLE_FP8
if
(
data_type
==
FP8_DATATYPE
)
{
//
if (data_type == FP8_DATATYPE) {
const
int8_t
fastAccuMode
=
1
;
// enable fast imprecise accum
//
const int8_t fastAccuMode = 1; // enable fast imprecise accum
status
=
cublasLtMatmulDescSetAttribute
(
//
status = cublasLtMatmulDescSetAttribute(
operationDesc
,
CUBLASLT_MATMUL_DESC_FAST_ACCUM
,
&
fastAccuMode
,
sizeof
(
decltype
(
fastAccuMode
)));
//
operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(decltype(fastAccuMode)));
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
}
//
}
#endif
//
#endif
// Create matrix descriptors. We are good with the details here so no need
//
// Create matrix descriptors. We are good with the details here so no need
// to set any extra attributes
//
// to set any extra attributes
if
(
data_type
==
FP8_DATATYPE
)
{
//
if (data_type == FP8_DATATYPE) {
status
=
cublasLtMatrixLayoutCreate
(
&
Adesc
,
Atype
,
k
,
m
,
k
);
//
status = cublasLtMatrixLayoutCreate(&Adesc, Atype, k, m, k);
}
//
}
else
{
//
else {
status
=
cublasLtMatrixLayoutCreate
(
&
Adesc
,
Atype
,
m
,
k
,
m
);
//
status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, m);
}
//
}
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutCreate
(
&
Bdesc
,
Btype
,
k
,
n
,
k
);
//
status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, k, n, k);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutCreate
(
&
Cdesc
,
Ctype
,
m
,
n
,
m
);
//
status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, m);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
status
=
cublasLtMatrixLayoutCreate
(
&
Ddesc
,
Dtype
,
m
,
n
,
m
);
//
status = cublasLtMatrixLayoutCreate(&Ddesc, Dtype, m, n, m);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
if
(
batchCount
>
1
)
{
//
if (batchCount > 1) {
check_cuda_error
(
cublasLtMatrixLayoutSetAttribute
(
//
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Adesc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batchCount
,
sizeof
(
batchCount
)));
//
Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
check_cuda_error
(
cublasLtMatrixLayoutSetAttribute
(
//
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Bdesc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batchCount
,
sizeof
(
batchCount
)));
//
Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
check_cuda_error
(
cublasLtMatrixLayoutSetAttribute
(
//
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Cdesc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batchCount
,
sizeof
(
batchCount
)));
//
Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
check_cuda_error
(
cublasLtMatrixLayoutSetAttribute
(
//
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Ddesc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batchCount
,
sizeof
(
batchCount
)));
//
Ddesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
check_cuda_error
(
cublasLtMatrixLayoutSetAttribute
(
//
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Adesc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
strideA
,
sizeof
(
strideA
)));
//
Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));
check_cuda_error
(
cublasLtMatrixLayoutSetAttribute
(
//
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Bdesc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
strideB
,
sizeof
(
strideB
)));
//
Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));
check_cuda_error
(
cublasLtMatrixLayoutSetAttribute
(
//
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Cdesc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
strideD
,
sizeof
(
strideD
)));
//
Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
check_cuda_error
(
cublasLtMatrixLayoutSetAttribute
(
//
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Ddesc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
strideD
,
sizeof
(
strideD
)));
//
Ddesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
}
//
}
// Create CUDA event to time the execution time of each algo
//
// Create CUDA event to time the execution time of each algo
if
(
cudaEventCreate
(
&
startEvent
,
cudaEventBlockingSync
)
!=
cudaSuccess
)
{
//
if (cudaEventCreate(&startEvent, cudaEventBlockingSync) != cudaSuccess) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
if
(
cudaEventCreate
(
&
stopEvent
,
cudaEventBlockingSync
)
!=
cudaSuccess
)
{
//
if (cudaEventCreate(&stopEvent, cudaEventBlockingSync) != cudaSuccess) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
// Request the 100 first AlgoId available
//
// Request the 100 first AlgoId available
status
=
cublasLtMatmulAlgoGetIds
(
//
status = cublasLtMatmulAlgoGetIds(
ltHandle
,
computeType
,
scaleType
,
Atype
,
Btype
,
Ctype
,
Dtype
,
ALGO_IDS
,
algoIdA
,
&
nbAlgoIds
);
//
ltHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, ALGO_IDS, algoIdA, &nbAlgoIds);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
goto
CLEANUP
;
//
goto CLEANUP;
}
//
}
if
(
nbAlgoIds
>
ALGO_IDS
)
{
//
if (nbAlgoIds > ALGO_IDS) {
printf
(
//
printf(
"Warning: the algo id count is not large enough to guarantee the best algo %d, %d
\n
"
,
nbAlgoIds
,
ALGO_IDS
);
//
"Warning: the algo id count is not large enough to guarantee the best algo %d, %d\n", nbAlgoIds, ALGO_IDS);
}
//
}
// Loop over the Algo IDs
//
// Loop over the Algo IDs
// This loop doesn't work for fp8 gemm
//
// This loop doesn't work for fp8 gemm
for
(
int
idx
=
0
;
(
idx
<
nbAlgoIds
)
&&
(
AlgoCount
<
AlgoCombinations
);
idx
++
)
{
//
for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) {
cublasLtMatmulAlgo_t
algo
;
//
cublasLtMatmulAlgo_t algo;
size_t
sizeWritten
=
0
;
//
size_t sizeWritten = 0;
/* Initialize algo structure with given Algp ID */
//
/* Initialize algo structure with given Algp ID */
status
=
//
status =
cublasLtMatmulAlgoInit
(
ltHandle
,
computeType
,
scaleType
,
Atype
,
Btype
,
Ctype
,
Dtype
,
algoIdA
[
idx
],
&
algo
);
//
cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, algoIdA[idx], &algo);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
//
if (status != CUBLAS_STATUS_SUCCESS) {
continue
;
//
continue;
}
//
}
// Query the tiles enums supported by that algo
//
// Query the tiles enums supported by that algo
cublasLtMatmulAlgoCapGetAttribute
(
&
algo
,
CUBLASLT_ALGO_CAP_TILE_IDS
,
NULL
,
0
,
&
sizeWritten
);
//
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten);
int
nbTiles
=
int
(
sizeWritten
/
sizeof
(
int
));
//
int nbTiles = int(sizeWritten / sizeof(int));
int
*
tileA
=
new
int
[
nbTiles
==
0
?
1
:
nbTiles
];
//
int* tileA = new int[nbTiles == 0 ? 1 : nbTiles];
if
(
nbTiles
==
0
)
{
//
if (nbTiles == 0) {
tileA
[
0
]
=
CUBLASLT_MATMUL_TILE_UNDEFINED
;
//
tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED;
nbTiles
=
1
;
//
nbTiles = 1;
}
//
}
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoCapGetAttribute
(
&
algo
,
CUBLASLT_ALGO_CAP_STAGES_IDS
,
NULL
,
0
,
&
sizeWritten
);
// //
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten);
int
nbStages
=
int
(
sizeWritten
/
sizeof
(
int
));
// //
int nbStages = int(sizeWritten / sizeof(int));
std
::
vector
<
int
>
stagesA
(
nbStages
==
0
?
1
:
nbStages
);
// //
std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages);
if
(
nbStages
==
0
)
{
// //
if (nbStages == 0) {
stagesA
[
0
]
=
CUBLASLT_MATMUL_STAGES_UNDEFINED
;
// //
stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
nbStages
=
1
;
// //
nbStages = 1;
}
// //
}
else
{
// //
else {
cublasLtMatmulAlgoCapGetAttribute
(
// //
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_STAGES_IDS
,
stagesA
.
data
(),
sizeof
(
int
)
*
nbStages
,
&
sizeWritten
);
// //
&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten);
}
// //
}
#endif
// //
#endif
int
splitkSupport
,
redMask
,
swizzlingMax
,
customOptionMax
;
//
int splitkSupport, redMask, swizzlingMax, customOptionMax;
// Retrieve Algo Capabilities attributes to be able to setup loop over
//
// Retrieve Algo Capabilities attributes to be able to setup loop over
// the different combinations
//
// the different combinations
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_TILE_IDS
,
tileA
,
sizeof
(
int
)
*
nbTiles
,
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_SPLITK_SUPPORT
,
&
splitkSupport
,
sizeof
(
splitkSupport
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK
,
&
redMask
,
sizeof
(
redMask
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT
,
&
swizzlingMax
,
sizeof
(
swizzlingMax
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute
(
//
cublasLtMatmulAlgoCapGetAttribute(
&
algo
,
CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX
,
&
customOptionMax
,
sizeof
(
customOptionMax
),
&
sizeWritten
);
//
&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten);
/* Loop over the different tiles */
//
/* Loop over the different tiles */
for
(
int
tileIdx
=
0
;
tileIdx
<
nbTiles
;
tileIdx
++
)
{
//
for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) {
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
/* Loop over different stages count */
// //
/* Loop over different stages count */
for
(
int
stagesIdx
=
0
;
stagesIdx
<
nbStages
;
stagesIdx
++
)
{
// //
for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) {
cublasLtMatmulAlgoConfigSetAttribute
(
// //
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
stagesA
[
stagesIdx
],
sizeof
(
stagesA
[
stagesIdx
]));
// //
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]));
#endif
// //
#endif
/* Loop over the different custom option if any */
//
/* Loop over the different custom option if any */
for
(
int
customOption
=
0
;
customOption
<=
customOptionMax
;
customOption
++
)
{
//
for (int customOption = 0; customOption <= customOptionMax; customOption++) {
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
customOption
,
sizeof
(
customOption
));
//
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption));
/* Loop over the CTAs swizzling support */
//
/* Loop over the CTAs swizzling support */
for
(
int
k
=
0
;
k
<=
swizzlingMax
;
k
++
)
{
//
for (int k = 0; k <= swizzlingMax; k++) {
int
splitK_trial
=
0
;
//
int splitK_trial = 0;
if
(
splitkSupport
)
{
//
if (splitkSupport) {
splitK_trial
+=
sizeof
(
splitKSequenceA
)
/
sizeof
(
splitKSequenceA
[
0
]);
//
splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
}
//
}
// Loop over the splitK value over a fixed sequence
//
// Loop over the splitK value over a fixed sequence
// splitKSequenceA in addition to the case where splitK
//
// splitKSequenceA in addition to the case where splitK
// is not enabled
//
// is not enabled
for
(
int
l
=
0
;
(
l
<
(
1
+
splitK_trial
))
&&
(
AlgoCount
<
AlgoCombinations
);
l
++
)
{
//
for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) {
/* Setup attribute of the algo to run */
//
/* Setup attribute of the algo to run */
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_TILE_ID
,
&
tileA
[
tileIdx
],
sizeof
(
tileA
[
tileIdx
]));
//
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]));
int
splitK_val
=
0
;
//
int splitK_val = 0;
int
redScheme
=
CUBLASLT_REDUCTION_SCHEME_NONE
;
//
int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
&
splitK_val
,
sizeof
(
splitK_val
));
//
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val));
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
k
,
sizeof
(
k
));
//
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k));
cublasLtMatmulAlgoConfigSetAttribute
(
//
cublasLtMatmulAlgoConfigSetAttribute(
&
algo
,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
&
redScheme
,
sizeof
(
int
));
//
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int));
if
(
l
>
0
)
{
// Split-K case
//
if (l > 0) { // Split-K case
splitK_val
=
splitKSequenceA
[
l
-
1
];
//
splitK_val = splitKSequenceA[l - 1];
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
//
cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
//
CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&
splitKSequenceA
[
l
-
1
],
//
&splitKSequenceA[l - 1],
sizeof
(
splitKSequenceA
[
l
-
1
]));
//
sizeof(splitKSequenceA[l - 1]));
/* Going over all the reduction scheme */
//
/* Going over all the reduction scheme */
for
(
redScheme
=
1
;
//
for (redScheme = 1;
redScheme
<
(
int
)
CUBLASLT_REDUCTION_SCHEME_MASK
&&
(
AlgoCount
<
AlgoCombinations
);
//
redScheme < (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations);
redScheme
=
redScheme
<<
1
)
{
//
redScheme = redScheme << 1) {
if
(
redScheme
&
redMask
)
{
//
if (redScheme & redMask) {
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
//
cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
//
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&
redScheme
,
//
&redScheme,
sizeof
(
redScheme
));
//
sizeof(redScheme));
cublasLtMatmulHeuristicResult_t
heurResult
;
//
cublasLtMatmulHeuristicResult_t heurResult;
cublasStatus_t
algoStatus
=
cublasLtMatmulAlgoCheck
(
//
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
ltHandle
,
operationDesc
,
Adesc
,
Bdesc
,
Cdesc
,
Cdesc
,
&
algo
,
&
heurResult
);
//
ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult);
if
(
heurResult
.
workspaceSize
>
workSpaceSize
)
{
//
if (heurResult.workspaceSize > workSpaceSize) {
// printf("not enough workspace!
//
// printf("not enough workspace!
// %ld\n",
//
// %ld\n",
// heurResult.workspaceSize);
//
// heurResult.workspaceSize);
algoStatus
=
CUBLAS_STATUS_NOT_SUPPORTED
;
// Not enough workspace
//
algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace
}
//
}
else
if
(
heurResult
.
workspaceSize
==
0
)
{
//
else if (heurResult.workspaceSize == 0) {
if
(
algoStatus
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (algoStatus == CUBLAS_STATUS_SUCCESS) {
algosRestrict
[
AlgoCountRestrict
++
]
=
algo
;
//
algosRestrict[AlgoCountRestrict++] = algo;
}
//
}
}
//
}
if
(
algoStatus
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (algoStatus == CUBLAS_STATUS_SUCCESS) {
algos
[
AlgoCount
++
]
=
algo
;
//
algos[AlgoCount++] = algo;
}
//
}
}
// end if
//
} // end if
}
// end for
//
} // end for
}
//
}
else
{
// Non-splitK case
//
else { // Non-splitK case
/* if user preference is ok with workspace */
//
/* if user preference is ok with workspace */
if
(
AlgoCount
<
AlgoCombinations
)
{
//
if (AlgoCount < AlgoCombinations) {
cublasLtMatmulHeuristicResult_t
heurResult
;
//
cublasLtMatmulHeuristicResult_t heurResult;
cublasStatus_t
algoStatus
=
cublasLtMatmulAlgoCheck
(
//
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
ltHandle
,
operationDesc
,
Adesc
,
Bdesc
,
Cdesc
,
Cdesc
,
&
algo
,
&
heurResult
);
//
ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult);
if
(
heurResult
.
workspaceSize
>
workSpaceSize
)
{
//
if (heurResult.workspaceSize > workSpaceSize) {
// printf("not enough workspace! %ld\n",
//
// printf("not enough workspace! %ld\n",
// heurResult.workspaceSize);
//
// heurResult.workspaceSize);
algoStatus
=
CUBLAS_STATUS_NOT_SUPPORTED
;
// Not
//
algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; // Not
// enough
//
// enough
// workspace
//
// workspace
}
//
}
else
if
(
heurResult
.
workspaceSize
==
0
)
{
//
else if (heurResult.workspaceSize == 0) {
if
(
algoStatus
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (algoStatus == CUBLAS_STATUS_SUCCESS) {
algosRestrict
[
AlgoCountRestrict
++
]
=
algo
;
//
algosRestrict[AlgoCountRestrict++] = algo;
}
//
}
}
//
}
if
(
algoStatus
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (algoStatus == CUBLAS_STATUS_SUCCESS) {
algos
[
AlgoCount
++
]
=
algo
;
//
algos[AlgoCount++] = algo;
}
//
}
}
//
}
}
//
}
}
// end l
//
} // end l
}
// end k
//
} // end k
}
// end customOption
//
} // end customOption
#if (CUDART_VERSION >= 11000)
// //
#if (CUDART_VERSION >= 11000)
}
// end stagesIdx
//
} // end stagesIdx
#endif
// //
#endif
}
// end tileIdx
//
} // end tileIdx
delete
[]
tileA
;
//
delete[] tileA;
}
// end idx
//
} // end idx
printf
(
"AlgoCount: %d
\n
"
,
AlgoCount
);
//
printf("AlgoCount: %d\n", AlgoCount);
if
(
data_type
==
FP8_DATATYPE
)
{
//
if (data_type == FP8_DATATYPE) {
assert
(
AlgoCount
==
0
);
//
assert(AlgoCount == 0);
}
//
}
if
(
AlgoCount
<
maxNumTraversal
&&
data_type
!=
FP8_DATATYPE
)
{
//
if (AlgoCount < maxNumTraversal && data_type != FP8_DATATYPE) {
// 0 <= workspacesize <= 32MB
//
// 0 <= workspacesize <= 32MB
for
(
int
i
=
0
;
i
<
AlgoCount
;
i
++
)
{
//
for (int i = 0; i < AlgoCount; i++) {
status
=
customMatmulRun
(
ltHandle
,
//
status = customMatmulRun(ltHandle,
operationDesc
,
//
operationDesc,
alpha
,
/* host or device pointer */
//
alpha, /* host or device pointer */
A
,
//
A,
Adesc
,
//
Adesc,
B
,
//
B,
Bdesc
,
//
Bdesc,
beta
,
/* host or device pointer */
//
beta, /* host or device pointer */
C
,
//
C,
Cdesc
,
//
Cdesc,
C
,
//
C,
Cdesc
,
//
Cdesc,
algos
[
i
],
//
algos[i],
kernelRepeats
,
//
kernelRepeats,
workSpace
,
//
workSpace,
workSpaceSize
,
//
workSpaceSize,
perfResults
[
i
],
//
perfResults[i],
stream
,
//
stream,
startEvent
,
//
startEvent,
stopEvent
);
//
stopEvent);
perfResults
[
i
].
status
=
status
;
//
perfResults[i].status = status;
// if (status == CUBLAS_STATUS_SUCCESS) AlgoCount++;
//
// if (status == CUBLAS_STATUS_SUCCESS) AlgoCount++;
}
//
}
}
//
}
else
{
//
else {
// Heuristic + workspacesize==0
//
// Heuristic + workspacesize==0
AlgoCount
=
0
;
//
AlgoCount = 0;
nbAlgoIds
=
0
;
//
nbAlgoIds = 0;
cublasLtMatmulPreference_t
pref
;
//
cublasLtMatmulPreference_t pref;
cublasLtMatmulPreferenceCreate
(
&
pref
);
//
cublasLtMatmulPreferenceCreate(&pref);
uint64_t
maxWorkSpaceSize
=
workSpaceSize
;
//(32MB)
//
uint64_t maxWorkSpaceSize = workSpaceSize; //(32MB)
cublasLtMatmulPreferenceSetAttribute
(
//
cublasLtMatmulPreferenceSetAttribute(
pref
,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
,
&
maxWorkSpaceSize
,
sizeof
(
maxWorkSpaceSize
));
//
pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &maxWorkSpaceSize, sizeof(maxWorkSpaceSize));
cublasLtMatmulHeuristicResult_t
heuristicResultsArray
[
maxNumTraversal
];
//
cublasLtMatmulHeuristicResult_t heuristicResultsArray[maxNumTraversal];
cublasLtMatmulAlgoGetHeuristic
(
ltHandle
,
//
cublasLtMatmulAlgoGetHeuristic(ltHandle,
operationDesc
,
//
operationDesc,
Adesc
,
//
Adesc,
Bdesc
,
//
Bdesc,
Cdesc
,
//
Cdesc,
Ddesc
,
//
Ddesc,
pref
,
//
pref,
maxNumTraversal
,
//
maxNumTraversal,
heuristicResultsArray
,
//
heuristicResultsArray,
&
nbAlgoIds
);
//
&nbAlgoIds);
cublasLtMatmulPreferenceDestroy
(
pref
);
//
cublasLtMatmulPreferenceDestroy(pref);
printf
(
"return %d and run heuristic algo
\n
"
,
nbAlgoIds
);
//
printf("return %d and run heuristic algo\n", nbAlgoIds);
for
(
int
i
=
0
;
i
<
nbAlgoIds
;
i
++
)
{
//
for (int i = 0; i < nbAlgoIds; i++) {
if
(
heuristicResultsArray
[
i
].
state
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (heuristicResultsArray[i].state == CUBLAS_STATUS_SUCCESS) {
status
=
customMatmulRun
(
ltHandle
,
//
status = customMatmulRun(ltHandle,
operationDesc
,
//
operationDesc,
alpha
,
/* host or device pointer */
//
alpha, /* host or device pointer */
A
,
//
A,
Adesc
,
//
Adesc,
B
,
//
B,
Bdesc
,
//
Bdesc,
beta
,
/* host or device pointer */
//
beta, /* host or device pointer */
C
,
//
C,
Cdesc
,
//
Cdesc,
C
,
//
C,
Ddesc
,
//
Ddesc,
heuristicResultsArray
[
i
].
algo
,
//
heuristicResultsArray[i].algo,
kernelRepeats
,
//
kernelRepeats,
workSpace
,
//
workSpace,
workSpaceSize
,
//
workSpaceSize,
perfResults
[
AlgoCount
],
//
perfResults[AlgoCount],
stream
,
//
stream,
startEvent
,
//
startEvent,
stopEvent
);
//
stopEvent);
perfResults
[
AlgoCount
].
status
=
status
;
//
perfResults[AlgoCount].status = status;
if
(
status
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount
++
;
//
AlgoCount++;
}
//
}
}
//
}
}
//
}
// workspacesize==0
//
// workspacesize==0
printf
(
"workspacesize==0, run %d algos
\n
"
,
AlgoCountRestrict
);
//
printf("workspacesize==0, run %d algos\n", AlgoCountRestrict);
for
(
int
i
=
0
;
i
<
AlgoCountRestrict
&&
i
<
(
maxNumTraversal
-
nbAlgoIds
);
i
++
)
{
//
for (int i = 0; i < AlgoCountRestrict && i < (maxNumTraversal - nbAlgoIds); i++) {
status
=
customMatmulRun
(
ltHandle
,
//
status = customMatmulRun(ltHandle,
operationDesc
,
//
operationDesc,
alpha
,
/* host or device pointer */
//
alpha, /* host or device pointer */
A
,
//
A,
Adesc
,
//
Adesc,
B
,
//
B,
Bdesc
,
//
Bdesc,
beta
,
/* host or device pointer */
//
beta, /* host or device pointer */
C
,
//
C,
Cdesc
,
//
Cdesc,
C
,
//
C,
Ddesc
,
//
Ddesc,
algosRestrict
[
i
],
//
algosRestrict[i],
kernelRepeats
,
//
kernelRepeats,
NULL
,
//
NULL,
0
,
//
0,
perfResults
[
AlgoCount
],
//
perfResults[AlgoCount],
stream
,
//
stream,
startEvent
,
//
startEvent,
stopEvent
);
//
stopEvent);
perfResults
[
AlgoCount
].
status
=
status
;
//
perfResults[AlgoCount].status = status;
if
(
status
==
CUBLAS_STATUS_SUCCESS
)
{
//
if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount
++
;
//
AlgoCount++;
}
//
}
}
//
}
}
//
}
// Sort the results per run duration
//
// Sort the results per run duration
std
::
sort
(
perfResults
,
perfResults
+
AlgoCount
,
time_compare
);
//
std::sort(perfResults, perfResults + AlgoCount, time_compare);
// Print timing and perf details
//
// Print timing and perf details
for
(
int
i
=
0
,
hasPrint
=
1
;
i
<
AlgoCount
;
i
++
)
{
//
for (int i = 0, hasPrint = 1; i < AlgoCount; i++) {
printf
(
"result %03d : "
,
i
);
//
printf("result %03d : ", i);
hasPrint
=
printPerfStructure
(
batch_size
,
//
hasPrint = printPerfStructure(batch_size,
seq_len
,
//
seq_len,
head_num
,
//
head_num,
size_per_head
,
//
size_per_head,
m
,
//
m,
n
,
//
n,
k
,
//
k,
perfResults
[
i
],
//
perfResults[i],
fout
,
//
fout,
data_type
,
//
data_type,
hasPrint
,
//
hasPrint,
batchCount
);
//
batchCount);
}
//
}
CLEANUP:
//
CLEANUP:
// Descriptors are no longer needed as all GPU work was already enqueued
//
// Descriptors are no longer needed as all GPU work was already enqueued
if
(
Cdesc
)
{
//
if (Cdesc) {
cublasLtMatrixLayoutDestroy
(
Cdesc
);
//
cublasLtMatrixLayoutDestroy(Cdesc);
}
//
}
if
(
Bdesc
)
{
//
if (Bdesc) {
cublasLtMatrixLayoutDestroy
(
Bdesc
);
//
cublasLtMatrixLayoutDestroy(Bdesc);
}
//
}
if
(
Adesc
)
{
//
if (Adesc) {
cublasLtMatrixLayoutDestroy
(
Adesc
);
//
cublasLtMatrixLayoutDestroy(Adesc);
}
//
}
if
(
operationDesc
)
{
//
if (operationDesc) {
cublasLtMatmulDescDestroy
(
operationDesc
);
//
cublasLtMatmulDescDestroy(operationDesc);
}
//
}
if
(
startEvent
)
{
//
if (startEvent) {
cudaEventDestroy
(
startEvent
);
//
cudaEventDestroy(startEvent);
}
//
}
if
(
stopEvent
)
{
//
if (stopEvent) {
cudaEventDestroy
(
stopEvent
);
//
cudaEventDestroy(stopEvent);
}
//
}
return
status
==
CUBLAS_STATUS_SUCCESS
?
0
:
1
;
//
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
//
}
template
int
LtHgemmCustomFind
(
cublasLtHandle_t
ltHandle
,
//
template int LtHgemmCustomFind(cublasLtHandle_t ltHandle,
int
batch_size
,
//
int batch_size,
int
seq_len
,
//
int seq_len,
int
head_num
,
//
int head_num,
int
size_per_head
,
//
int size_per_head,
int
m
,
//
int m,
int
n
,
//
int n,
int
k
,
//
int k,
const
float
*
alpha
,
/* host pointer */
//
const float* alpha, /* host pointer */
const
float
*
A
,
//
const float* A,
const
float
*
B
,
//
const float* B,
const
float
*
beta
,
/* host pointer */
//
const float* beta, /* host pointer */
float
*
C
,
//
float* C,
void
*
workSpace
,
//
void* workSpace,
size_t
workSpaceSize
,
//
size_t workSpaceSize,
FILE
*
fout
,
//
FILE* fout,
customMatmulPerf_t
perfResults
[],
//
customMatmulPerf_t perfResults[],
int
AlgoCombinations
,
//
int AlgoCombinations,
cudaDataType_t
dtype_fp8
,
//
cudaDataType_t dtype_fp8,
int
batchCount
,
//
int batchCount,
int64_t
strideA
,
//
int64_t strideA,
int64_t
strideB
,
//
int64_t strideB,
int64_t
strideD
);
//
int64_t strideD);
template
int
LtHgemmCustomFind
(
cublasLtHandle_t
ltHandle
,
//
template int LtHgemmCustomFind(cublasLtHandle_t ltHandle,
int
batch_size
,
//
int batch_size,
int
seq_len
,
//
int seq_len,
int
head_num
,
//
int head_num,
int
size_per_head
,
//
int size_per_head,
int
m
,
//
int m,
int
n
,
//
int n,
int
k
,
//
int k,
const
half
*
alpha
,
/* host pointer */
//
const half* alpha, /* host pointer */
const
half
*
A
,
//
const half* A,
const
half
*
B
,
//
const half* B,
const
half
*
beta
,
/* host pointer */
//
const half* beta, /* host pointer */
half
*
C
,
//
half* C,
void
*
workSpace
,
//
void* workSpace,
size_t
workSpaceSize
,
//
size_t workSpaceSize,
FILE
*
fout
,
//
FILE* fout,
customMatmulPerf_t
perfResults
[],
//
customMatmulPerf_t perfResults[],
int
AlgoCombinations
,
//
int AlgoCombinations,
cudaDataType_t
dtype_fp8
,
//
cudaDataType_t dtype_fp8,
int
batchCount
,
//
int batchCount,
int64_t
strideA
,
//
int64_t strideA,
int64_t
strideB
,
//
int64_t strideB,
int64_t
strideD
);
//
int64_t strideD);
#ifdef ENABLE_BF16
//
#ifdef ENABLE_BF16
template
int
LtHgemmCustomFind
(
cublasLtHandle_t
ltHandle
,
//
template int LtHgemmCustomFind(cublasLtHandle_t ltHandle,
int
batch_size
,
//
int batch_size,
int
seq_len
,
//
int seq_len,
int
head_num
,
//
int head_num,
int
size_per_head
,
//
int size_per_head,
int
m
,
//
int m,
int
n
,
//
int n,
int
k
,
//
int k,
const
float
*
alpha
,
/* host pointer */
//
const float* alpha, /* host pointer */
const
__nv_bfloat16
*
A
,
//
const __nv_bfloat16* A,
const
__nv_bfloat16
*
B
,
//
const __nv_bfloat16* B,
const
float
*
beta
,
/* host pointer */
//
const float* beta, /* host pointer */
__nv_bfloat16
*
C
,
//
__nv_bfloat16* C,
void
*
workSpace
,
//
void* workSpace,
size_t
workSpaceSize
,
//
size_t workSpaceSize,
FILE
*
fout
,
//
FILE* fout,
customMatmulPerf_t
perfResults
[],
//
customMatmulPerf_t perfResults[],
int
AlgoCombinations
,
//
int AlgoCombinations,
cudaDataType_t
dtype_fp8
,
//
cudaDataType_t dtype_fp8,
int
batchCount
,
//
int batchCount,
int64_t
strideA
,
//
int64_t strideA,
int64_t
strideB
,
//
int64_t strideB,
int64_t
strideD
);
//
int64_t strideD);
#endif
//
#endif
#ifdef ENABLE_FP8
//
#ifdef ENABLE_FP8
template
int
LtHgemmCustomFind
(
cublasLtHandle_t
ltHandle
,
//
template int LtHgemmCustomFind(cublasLtHandle_t ltHandle,
int
batch_size
,
//
int batch_size,
int
seq_len
,
//
int seq_len,
int
head_num
,
//
int head_num,
int
size_per_head
,
//
int size_per_head,
int
m
,
//
int m,
int
n
,
//
int n,
int
k
,
//
int k,
const
float
*
alpha
,
/* host pointer */
//
const float* alpha, /* host pointer */
const
__nv_fp8_e4m3
*
A
,
//
const __nv_fp8_e4m3* A,
const
__nv_fp8_e4m3
*
B
,
//
const __nv_fp8_e4m3* B,
const
float
*
beta
,
/* host pointer */
//
const float* beta, /* host pointer */
__nv_fp8_e4m3
*
C
,
//
__nv_fp8_e4m3* C,
void
*
workSpace
,
//
void* workSpace,
size_t
workSpaceSize
,
//
size_t workSpaceSize,
FILE
*
fout
,
//
FILE* fout,
customMatmulPerf_t
perfResults
[],
//
customMatmulPerf_t perfResults[],
int
AlgoCombinations
,
//
int AlgoCombinations,
cudaDataType_t
dtype_fp8
,
//
cudaDataType_t dtype_fp8,
int
batchCount
,
//
int batchCount,
int64_t
strideA
,
//
int64_t strideA,
int64_t
strideB
,
//
int64_t strideB,
int64_t
strideD
);
//
int64_t strideD);
#endif
//
#endif
template
int
LtHgemmCustomFind
(
cublasLtHandle_t
ltHandle
,
//
template int LtHgemmCustomFind(cublasLtHandle_t ltHandle,
int
batch_size
,
//
int batch_size,
int
seq_len
,
//
int seq_len,
int
head_num
,
//
int head_num,
int
size_per_head
,
//
int size_per_head,
int
m
,
//
int m,
int
n
,
//
int n,
int
k
,
//
int k,
const
float
*
alpha
,
/* host pointer */
//
const float* alpha, /* host pointer */
const
half
*
A
,
//
const half* A,
const
half
*
B
,
//
const half* B,
const
float
*
beta
,
/* host pointer */
//
const float* beta, /* host pointer */
half
*
C
,
//
half* C,
void
*
workSpace
,
//
void* workSpace,
size_t
workSpaceSize
,
//
size_t workSpaceSize,
FILE
*
fout
,
//
FILE* fout,
customMatmulPerf_t
perfResults
[],
//
customMatmulPerf_t perfResults[],
int
AlgoCombinations
,
//
int AlgoCombinations,
cudaDataType_t
dtype_fp8
,
//
cudaDataType_t dtype_fp8,
int
batchCount
,
//
int batchCount,
int64_t
strideA
,
//
int64_t strideA,
int64_t
strideB
,
//
int64_t strideB,
int64_t
strideD
);
//
int64_t strideD);
size_t
calGemmTestBufSizeInByte
(
int
batch_size
,
size_t
calGemmTestBufSizeInByte
(
int
batch_size
,
int
seq_len
,
int
seq_len
,
...
...
src/turbomind/utils/gemm_test/gpt_gemm_func.cc
View file @
9484fd1c
...
@@ -223,8 +223,8 @@ void generate_gpt_gemm_config(int batch_size,
...
@@ -223,8 +223,8 @@ void generate_gpt_gemm_config(int batch_size,
cublasHandle_t
cublas_handle
;
cublasHandle_t
cublas_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
cublasLtHandle_t
ltHandle
;
//
cublasLtHandle_t ltHandle;
check_cuda_error
(
cublasLtCreate
(
&
ltHandle
));
//
check_cuda_error(cublasLtCreate(<Handle));
cudaDataType_t
AType
;
cudaDataType_t
AType
;
cudaDataType_t
BType
;
cudaDataType_t
BType
;
...
@@ -244,7 +244,8 @@ void generate_gpt_gemm_config(int batch_size,
...
@@ -244,7 +244,8 @@ void generate_gpt_gemm_config(int batch_size,
DType
=
CUDA_R_32F
;
DType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO23
;
// endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
data_type
=
HALF_DATATYPE
;
data_type
=
HALF_DATATYPE
;
...
@@ -252,9 +253,11 @@ void generate_gpt_gemm_config(int batch_size,
...
@@ -252,9 +253,11 @@ void generate_gpt_gemm_config(int batch_size,
BType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
DType
=
CUDA_R_16F
;
DType
=
CUDA_R_16F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_16F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#ifdef ENABLE_BF16
#ifdef ENABLE_BF16
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
...
@@ -264,8 +267,10 @@ void generate_gpt_gemm_config(int batch_size,
...
@@ -264,8 +267,10 @@ void generate_gpt_gemm_config(int batch_size,
CType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
DType
=
CUDA_R_16BF
;
DType
=
CUDA_R_16BF
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#endif
#endif
#ifdef ENABLE_FP8
#ifdef ENABLE_FP8
...
@@ -293,12 +298,24 @@ void generate_gpt_gemm_config(int batch_size,
...
@@ -293,12 +298,24 @@ void generate_gpt_gemm_config(int batch_size,
DType_FP8
[
9
]
=
CUDA_R_16BF
;
DType_FP8
[
9
]
=
CUDA_R_16BF
;
#endif
#endif
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#endif
#endif
float
alpha
=
(
float
)
1.0
f
;
// float alpha = (float)1.0f;
float
beta
=
(
float
)
0.0
f
;
// float beta = (float)0.0f;
float
f_alpha
=
(
float
)
1.0
f
;
float
f_beta
=
(
float
)
0.0
f
;
half
h_alpha
=
(
half
)(
f_alpha
);
half
h_beta
=
(
half
)(
f_beta
);
int
is_fp16_computeType
=
computeType
==
CUDA_R_16F
?
1
:
0
;
const
void
*
alpha
=
is_fp16_computeType
?
reinterpret_cast
<
void
*>
(
&
h_alpha
)
:
reinterpret_cast
<
void
*>
(
&
f_alpha
);
const
void
*
beta
=
is_fp16_computeType
?
reinterpret_cast
<
void
*>
(
&
h_beta
)
:
reinterpret_cast
<
void
*>
(
&
f_beta
);
printf
(
"***Encoder Gemm Testing Begin***
\n
"
);
printf
(
"***Encoder Gemm Testing Begin***
\n
"
);
printf
(
"***Cublas Gemm Testing Begin***
\n
"
);
printf
(
"***Cublas Gemm Testing Begin***
\n
"
);
...
@@ -342,7 +359,7 @@ void generate_gpt_gemm_config(int batch_size,
...
@@ -342,7 +359,7 @@ void generate_gpt_gemm_config(int batch_size,
max_input_len
,
max_input_len
,
max_input_len
,
max_input_len
,
size_per_head
,
size_per_head
,
&
alpha
,
&
f_
alpha
,
d_B
,
d_B
,
BType
,
BType
,
size_per_head
,
size_per_head
,
...
@@ -351,13 +368,13 @@ void generate_gpt_gemm_config(int batch_size,
...
@@ -351,13 +368,13 @@ void generate_gpt_gemm_config(int batch_size,
AType
,
AType
,
size_per_head
,
size_per_head
,
max_input_len
*
size_per_head
,
max_input_len
*
size_per_head
,
&
beta
,
&
f_
beta
,
d_C
,
d_C
,
CUDA_R_32F
,
// CType,
CUDA_R_32F
,
// CType,
max_input_len
,
max_input_len
,
max_input_len
*
max_input_len
,
max_input_len
*
max_input_len
,
batchCount
[
i
],
batchCount
[
i
],
computeType
,
CUDA_R_32F
,
static_cast
<
cublasGemmAlgo_t
>
(
algo
));
static_cast
<
cublasGemmAlgo_t
>
(
algo
));
}
}
else
if
(
i
==
2
)
{
else
if
(
i
==
2
)
{
...
@@ -456,44 +473,45 @@ void generate_gpt_gemm_config(int batch_size,
...
@@ -456,44 +473,45 @@ void generate_gpt_gemm_config(int batch_size,
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
// for gpt, computeType & scaleType should be FP32
// for gpt, computeType & scaleType should be FP32
LtHgemmCustomFind
<
T
,
float
>
(
ltHandle
,
// LtHgemmCustomFind<T, float>(ltHandle,
batch_size
*
beam_width
,
// batch_size * beam_width,
i
==
1
||
i
==
2
?
max_input_len
:
1
,
// i == 1 || i == 2 ? max_input_len : 1,
head_num
,
// head_num,
size_per_head
,
// size_per_head,
n
,
// n,
m
,
// m,
k
,
// k,
&
alpha
,
// &alpha,
d_B
,
// d_B,
d_A
,
// d_A,
&
beta
,
// &beta,
d_C
,
// d_C,
cublas_workspace
,
// cublas_workspace,
workSpaceSize
,
// workSpaceSize,
fd
,
// fd,
perfResults
,
// perfResults,
ALGO_COMBINATIONS
,
// ALGO_COMBINATIONS,
DType_FP8
[
i
],
// DType_FP8[i],
batchCount
[
i
],
// batchCount[i],
strideA
[
i
],
// strideA[i],
strideB
[
i
],
// strideB[i],
strideD
[
i
]);
// strideD[i]);
if
(
perfResults
[
0
].
time
<
exec_time
)
{
// if (perfResults[0].time < exec_time) {
printPerfStructure
(
batch_size
*
beam_width
,
// printPerfStructure(batch_size * beam_width,
seq_len
,
// seq_len,
head_num
,
// head_num,
size_per_head
,
// size_per_head,
n
,
// n,
m
,
// m,
k
,
// k,
perfResults
[
0
],
// perfResults[0],
fd
,
// fd,
data_type
,
// data_type,
0
,
// 0,
batchCount
[
i
]);
// batchCount[i]);
}
// }
else
{
// else {
{
fprintf
(
fd
,
fprintf
(
fd
,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
...
...
src/turbomind/utils/gemm_test/swin_gemm_func.cc
View file @
9484fd1c
...
@@ -133,8 +133,8 @@ void generate_swin_gemm_config(
...
@@ -133,8 +133,8 @@ void generate_swin_gemm_config(
cublasHandle_t
cublas_handle
;
cublasHandle_t
cublas_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
cublasLtHandle_t
ltHandle
;
//
cublasLtHandle_t ltHandle;
check_cuda_error
(
cublasLtCreate
(
&
ltHandle
));
//
check_cuda_error(cublasLtCreate(<Handle));
cudaDataType_t
AType
;
cudaDataType_t
AType
;
cudaDataType_t
BType
;
cudaDataType_t
BType
;
...
@@ -151,16 +151,19 @@ void generate_swin_gemm_config(
...
@@ -151,16 +151,19 @@ void generate_swin_gemm_config(
CType
=
CUDA_R_32F
;
CType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO23
;
// endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
data_type
=
HALF_DATATYPE
;
data_type
=
HALF_DATATYPE
;
AType
=
CUDA_R_16F
;
AType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_16F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#ifdef ENABLE_BF16
#ifdef ENABLE_BF16
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
...
@@ -169,11 +172,14 @@ void generate_swin_gemm_config(
...
@@ -169,11 +172,14 @@ void generate_swin_gemm_config(
BType
=
CUDA_R_16BF
;
BType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#endif
#endif
using
scaleT
=
typename
ScaleTypeConverter
<
T
,
false
>::
Type
;
// using scaleT = typename ScaleTypeConverter<T, false>::Type;
using
scaleT
=
typename
ScaleTypeConverter
<
T
,
true
>::
Type
;
scaleT
alpha
=
(
scaleT
)
1.0
f
;
scaleT
alpha
=
(
scaleT
)
1.0
f
;
scaleT
beta
=
(
scaleT
)
0.0
f
;
scaleT
beta
=
(
scaleT
)
0.0
f
;
...
@@ -309,30 +315,31 @@ void generate_swin_gemm_config(
...
@@ -309,30 +315,31 @@ void generate_swin_gemm_config(
const
int
ALGO_COMBINATIONS
=
5000
;
const
int
ALGO_COMBINATIONS
=
5000
;
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
LtHgemmCustomFind
<
T
,
scaleT
>
(
ltHandle
,
// LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size
,
// batch_size,
seq_len
,
// seq_len,
head_num
,
// head_num,
size_per_head
,
// size_per_head,
n
,
// n,
m
,
// m,
k
,
// k,
&
alpha
,
// &alpha,
d_B
,
// d_B,
d_A
,
// d_A,
&
beta
,
// &beta,
d_C
,
// d_C,
cublas_workspace
,
// cublas_workspace,
workSpaceSize
,
// workSpaceSize,
fd
,
// fd,
perfResults
,
// perfResults,
ALGO_COMBINATIONS
);
// ALGO_COMBINATIONS);
if
(
perfResults
[
0
].
time
<
exec_time
)
{
// if (perfResults[0].time < exec_time) {
printPerfStructure
(
// printPerfStructure(
batch_size
,
seq_len
,
head_num
,
size_per_head
,
n
,
m
,
k
,
perfResults
[
0
],
fd
,
data_type
,
0
);
// batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time
=
perfResults
[
0
].
time
;
// exec_time = perfResults[0].time;
}
// }
else
{
// else {
{
fprintf
(
fd
,
fprintf
(
fd
,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
...
...
src/turbomind/utils/gemm_test/swin_igemm_func.cc
View file @
9484fd1c
...
@@ -144,23 +144,23 @@ int igemm_config_INT8IO(int m, int n, int k, FILE* fout, void* buffer)
...
@@ -144,23 +144,23 @@ int igemm_config_INT8IO(int m, int n, int k, FILE* fout, void* buffer)
int8_t
*
d_B
=
d_A
+
m
*
k
;
// k * n, stored in column-major
int8_t
*
d_B
=
d_A
+
m
*
k
;
// k * n, stored in column-major
int8_t
*
d_C
=
(
int8_t
*
)(
d_B
+
k
*
n
);
// m * n, stored in column-major
int8_t
*
d_C
=
(
int8_t
*
)(
d_B
+
k
*
n
);
// m * n, stored in column-major
cublasLtHandle_t
ltHandle
;
//
cublasLtHandle_t ltHandle;
cublasLtCreate
(
&
ltHandle
);
//
cublasLtCreate(<Handle);
LtIgemmCustomFind
(
ltHandle
,
//
LtIgemmCustomFind(ltHandle,
m
,
//
m,
n
,
//
n,
k
,
//
k,
&
alpha
,
/* host pointer */
//
&alpha, /* host pointer */
d_A
,
//
d_A,
d_B
,
//
d_B,
&
beta
,
/* host pointer */
//
&beta, /* host pointer */
d_C
,
//
d_C,
NULL
,
//
NULL,
0
,
//
0,
fout
);
//
fout);
cublasLtDestroy
(
ltHandle
);
//
cublasLtDestroy(ltHandle);
return
0
;
return
0
;
}
}
...
...
src/turbomind/utils/gemm_test/t5_gemm_func.cc
View file @
9484fd1c
...
@@ -195,8 +195,8 @@ void generate_t5_gemm_config(int batch_size,
...
@@ -195,8 +195,8 @@ void generate_t5_gemm_config(int batch_size,
cublasHandle_t
cublas_handle
;
cublasHandle_t
cublas_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
cublasLtHandle_t
ltHandle
;
//
cublasLtHandle_t ltHandle;
check_cuda_error
(
cublasLtCreate
(
&
ltHandle
));
//
check_cuda_error(cublasLtCreate(<Handle));
cudaDataType_t
AType
;
cudaDataType_t
AType
;
cudaDataType_t
BType
;
cudaDataType_t
BType
;
...
@@ -213,16 +213,19 @@ void generate_t5_gemm_config(int batch_size,
...
@@ -213,16 +213,19 @@ void generate_t5_gemm_config(int batch_size,
CType
=
CUDA_R_32F
;
CType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO23
;
// endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
data_type
=
HALF_DATATYPE
;
data_type
=
HALF_DATATYPE
;
AType
=
CUDA_R_16F
;
AType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_16F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#ifdef ENABLE_BF16
#ifdef ENABLE_BF16
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
...
@@ -231,8 +234,10 @@ void generate_t5_gemm_config(int batch_size,
...
@@ -231,8 +234,10 @@ void generate_t5_gemm_config(int batch_size,
BType
=
CUDA_R_16BF
;
BType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#endif
#endif
float
f_alpha
=
(
float
)
1.0
f
;
float
f_alpha
=
(
float
)
1.0
f
;
...
@@ -442,60 +447,61 @@ void generate_t5_gemm_config(int batch_size,
...
@@ -442,60 +447,61 @@ void generate_t5_gemm_config(int batch_size,
scaleT
alpha_scale
=
(
scaleT
)
1.0
f
;
scaleT
alpha_scale
=
(
scaleT
)
1.0
f
;
scaleT
beta_scale
=
(
scaleT
)
0.0
f
;
scaleT
beta_scale
=
(
scaleT
)
0.0
f
;
LtHgemmCustomFind
<
T
,
scaleT
>
(
ltHandle
,
//
LtHgemmCustomFind<T, scaleT>(ltHandle,
m
,
//
m,
seq_len
,
//
seq_len,
head_num
,
//
head_num,
size_per_head
,
//
size_per_head,
n
,
//
n,
m
,
//
m,
k
,
//
k,
&
(
alpha_scale
),
//
&(alpha_scale),
d_B
,
//
d_B,
d_A
,
//
d_A,
&
(
beta_scale
),
//
&(beta_scale),
d_C
,
//
d_C,
cublas_workspace
,
//
cublas_workspace,
workSpaceSize
,
//
workSpaceSize,
fd
,
//
fd,
perfResults
,
//
perfResults,
ALGO_COMBINATIONS
);
//
ALGO_COMBINATIONS);
}
}
else
{
else
{
LtHgemmCustomFind
<
T
,
float
>
(
ltHandle
,
//
LtHgemmCustomFind<T, float>(ltHandle,
m
,
//
m,
seq_len
,
//
seq_len,
head_num
,
//
head_num,
size_per_head
,
//
size_per_head,
n
,
//
n,
m
,
//
m,
k
,
//
k,
&
(
f_alpha
),
//
&(f_alpha),
d_B
,
//
d_B,
d_A
,
//
d_A,
&
(
f_beta
),
//
&(f_beta),
d_C
,
//
d_C,
cublas_workspace
,
//
cublas_workspace,
workSpaceSize
,
//
workSpaceSize,
fd
,
//
fd,
perfResults
,
//
perfResults,
ALGO_COMBINATIONS
);
//
ALGO_COMBINATIONS);
}
}
if
(
perfResults
[
0
].
time
<
exec_time
)
{
// if (perfResults[0].time < exec_time) {
printPerfStructure
(
batch_size
*
(
i
<=
5
||
i
==
1
?
1
:
beam_width
),
// printPerfStructure(batch_size * (i <= 5 || i == 1 ? 1 : beam_width),
seq_len
,
// seq_len,
head_num
,
// head_num,
size_per_head
,
// size_per_head,
n
,
// n,
m
,
// m,
k
,
// k,
perfResults
[
0
],
// perfResults[0],
fd
,
// fd,
data_type
,
// data_type,
0
);
// 0);
}
// }
else
{
// else {
{
fprintf
(
fd
,
fprintf
(
fd
,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
...
...
src/turbomind/utils/gemm_test/xlnet_gemm_func.cc
View file @
9484fd1c
...
@@ -218,8 +218,8 @@ void generate_xlnet_gemm_config(int batch_size,
...
@@ -218,8 +218,8 @@ void generate_xlnet_gemm_config(int batch_size,
cublasHandle_t
cublas_handle
;
cublasHandle_t
cublas_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
cublasLtHandle_t
ltHandle
;
//
cublasLtHandle_t ltHandle;
check_cuda_error
(
cublasLtCreate
(
&
ltHandle
));
//
check_cuda_error(cublasLtCreate(<Handle));
cudaDataType_t
AType
;
cudaDataType_t
AType
;
cudaDataType_t
BType
;
cudaDataType_t
BType
;
...
@@ -236,16 +236,19 @@ void generate_xlnet_gemm_config(int batch_size,
...
@@ -236,16 +236,19 @@ void generate_xlnet_gemm_config(int batch_size,
CType
=
CUDA_R_32F
;
CType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO23
;
// endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
data_type
=
HALF_DATATYPE
;
data_type
=
HALF_DATATYPE
;
AType
=
CUDA_R_16F
;
AType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
BType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_16F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#ifdef ENABLE_BF16
#ifdef ENABLE_BF16
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
...
@@ -254,12 +257,15 @@ void generate_xlnet_gemm_config(int batch_size,
...
@@ -254,12 +257,15 @@ void generate_xlnet_gemm_config(int batch_size,
BType
=
CUDA_R_16BF
;
BType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
computeType
=
CUDA_R_32F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
}
#endif
#endif
using
scaleT
=
typename
ScaleTypeConverter
<
T
,
false
>::
Type
;
// using scaleT = typename ScaleTypeConverter<T, false>::Type;
using
scaleT
=
typename
ScaleTypeConverter
<
T
,
true
>::
Type
;
scaleT
alpha
=
(
scaleT
)
1.0
f
;
scaleT
alpha
=
(
scaleT
)
1.0
f
;
scaleT
beta
=
(
scaleT
)
0.0
f
;
scaleT
beta
=
(
scaleT
)
0.0
f
;
...
@@ -358,30 +364,31 @@ void generate_xlnet_gemm_config(int batch_size,
...
@@ -358,30 +364,31 @@ void generate_xlnet_gemm_config(int batch_size,
const
int
ALGO_COMBINATIONS
=
5000
;
const
int
ALGO_COMBINATIONS
=
5000
;
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
LtHgemmCustomFind
<
T
,
scaleT
>
(
ltHandle
,
// LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size
,
// batch_size,
seq_len
,
// seq_len,
head_num
,
// head_num,
size_per_head
,
// size_per_head,
n
,
// n,
m
,
// m,
k
,
// k,
&
alpha
,
// &alpha,
d_B
,
// d_B,
d_A
,
// d_A,
&
beta
,
// &beta,
d_C
,
// d_C,
cublas_workspace
,
// cublas_workspace,
workSpaceSize
,
// workSpaceSize,
fd
,
// fd,
perfResults
,
// perfResults,
ALGO_COMBINATIONS
);
// ALGO_COMBINATIONS);
if
(
perfResults
[
0
].
time
<
exec_time
)
{
// if (perfResults[0].time < exec_time) {
printPerfStructure
(
// printPerfStructure(
batch_size
,
seq_len
,
head_num
,
size_per_head
,
n
,
m
,
k
,
perfResults
[
0
],
fd
,
data_type
,
0
);
// batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time
=
perfResults
[
0
].
time
;
// exec_time = perfResults[0].time;
}
// }
else
{
// else {
{
fprintf
(
fd
,
fprintf
(
fd
,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment