Unverified commit 4d42a781, authored by Li Zhang, committed by GitHub

fix-gemm-tuning (#24)

parent e357c71f
@@ -4,15 +4,15 @@ import subprocess
 import fire
-def main(head_num: int = 80,
+def main(head_num: int = 32,
          size_per_head: int = 128,
-         vocab_size: int = 65632,
+         vocab_size: int = 32000,
-         inter_size: int = 27392,
+         inter_size: int = 11008,
-         tensor_para_size: int = 8,
+         tensor_para_size: int = 1,
          max_batch_size: int = 64):
     for bsz in range(1, max_batch_size + 1):
         subprocess.call(
-            f'bin/gpt_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size} {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}',
+            f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size} {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}',
             shell=True)
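
For context, a minimal Python sketch of the command lines the patched script now emits with its new single-GPU defaults (head_num=32, size_per_head=128, inter_size=11008, vocab_size=32000, tensor_para_size=1); the binary name bin/llama_gemm and the argument order are taken directly from the diff, while the meaning of the trailing 0/1 flag is an assumption:

# Sketch only: prints the tuning commands rather than running bin/llama_gemm.
def gemm_cmd(bsz: int,
             head_num: int = 32,
             size_per_head: int = 128,
             inter_size: int = 11008,
             vocab_size: int = 32000,
             tensor_para_size: int = 1) -> str:
    # Mirrors `{0 if bsz == 1 else 1}` in the script; treating it as an
    # overwrite/append switch for the generated config is an assumption,
    # not something the diff states.
    flag = 0 if bsz == 1 else 1
    return (f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} '
            f'{inter_size} {vocab_size} 1 {tensor_para_size} {flag}')

for bsz in (1, 2, 64):
    print(gemm_cmd(bsz))
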
@@ -270,10 +270,10 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle,
     // Let try a fixed number of combinations
     int AlgoCount = 0;
     int AlgoCountRestrict = 0;  // workspace == 0
-    int maxNumTraversal = 50;  // max number of traversal
+    const int maxNumTraversal = 50;  // max number of traversal
     cublasLtMatmulAlgo_t algos[AlgoCombinations];          // 0 <= workspace <= 32MB
     cublasLtMatmulAlgo_t algosRestrict[AlgoCombinations];  // workspace == 0
-    int kernelRepeats = 100;  // number of time the CUDA kernels will be run back to back
+    const int kernelRepeats = 100;  // number of time the CUDA kernels will be run back to back
     int nbAlgoIds = 0;  // Number of algorithms actually returned by
                         // cublasLtMatmulAlgoGetIds function.
 #define ALGO_IDS 100  // Number of algorithms requested.
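
A rough sketch of the search budget implied by the constants above, assuming the cap works by timing only the first maxNumTraversal candidates (the exact selection logic is not shown in this hunk); the numbers are illustrative arithmetic, not measurements:

# Illustrative only: kernel launches the tuning sweep costs per GEMM shape.
ALGO_IDS = 100          # algorithm IDs requested from cublasLtMatmulAlgoGetIds
MAX_NUM_TRAVERSAL = 50  # cap on candidate configurations that actually get timed
KERNEL_REPEATS = 100    # back-to-back runs per timed candidate

def kernel_launches(num_candidates: int) -> int:
    # Assumption: the cap simply truncates the candidate list.
    return min(num_candidates, MAX_NUM_TRAVERSAL) * KERNEL_REPEATS

print(kernel_launches(ALGO_IDS))  # 50 * 100 = 5000 launches
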
@@ -39,6 +39,7 @@ void generate_gpt_gemm_config(int batch_size,
     void* cublas_workspace;
     void* buffer;
     int workSpaceSize;
+#if 0
     bool workspace_flag = std::is_same<T, half>::value;
 #ifdef ENABLE_FP8
     workspace_flag = workspace_flag || std::is_same<T, __nv_fp8_e4m3>::value;
@@ -46,6 +47,9 @@ void generate_gpt_gemm_config(int batch_size,
 #if ENABLE_BF16
     workspace_flag = workspace_flag || std::is_same<T, __nv_bfloat16>::value;
 #endif
+#endif
+    // algorithms with workspace perform worse than evaluated
+    const bool workspace_flag = 0;
     if (workspace_flag) {
         // cublas_workspace_ should be the start pointer of cudaMalloc()
         // to ensure 16B alignemnet
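
Taken together, the two hunks above compile out the type-based workspace_flag computation with #if 0 and replace it with a hard false, so the tuner now only benchmarks cublasLt algorithms that run with a zero-byte workspace; the added comment gives the rationale (workspace-backed algorithms performed worse in evaluation). A minimal sketch of the resulting branch, where the 32 MB figure is an assumption taken from the "0 <= workspace <= 32MB" comment in the earlier hunk:

# Sketch of the post-patch behaviour; 32 MB is an assumed upper bound,
# not a value read from this diff.
CUBLAS_WORKSPACE_BYTES = 32 * 1024 * 1024

def tuning_workspace_bytes(workspace_flag: bool) -> int:
    # workspace_flag is hard-coded to false after this commit, so only the
    # zero-byte branch is reachable.
    return CUBLAS_WORKSPACE_BYTES if workspace_flag else 0

print(tuning_workspace_bytes(False))  # 0
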
@@ -310,7 +314,8 @@ void generate_gpt_gemm_config(int batch_size,
     }
     for (int i = 0; i < gemm_num; ++i) {
-        if (i <= 5) {
+        // tuning of context gemm and logits gemm is not working yet
+        if (i <= 5 || i == 10) {
             continue;
         }
         int seq_len = i <= 5 ? max_input_len : 1;
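
With the widened skip condition, the loop now leaves out indices 0 through 5 (the context GEMMs, which use max_input_len as their sequence length) and index 10, which the added comment associates with the logits GEMM. A small sketch of which indices remain tuned, assuming gemm_num covers indices 0..10 as the i == 10 check suggests:

# Sketch only: GEMM_NUM = 11 is an assumption inferred from the `i == 10` check.
GEMM_NUM = 11

def is_skipped(i: int) -> bool:
    # Mirrors `if (i <= 5 || i == 10) continue;`
    return i <= 5 or i == 10

print([i for i in range(GEMM_NUM) if not is_skipped(i)])  # [6, 7, 8, 9]
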
@@ -445,7 +450,7 @@ void generate_gpt_gemm_config(int batch_size,
         if ((data_type != FLOAT_DATATYPE && i != 1 && i != 2 && i != 10) || data_type == FP8_DATATYPE) {
             printf("***cublasLt Gemm Testing Beign***\n");
             // Let try a fixed number of combinations
-            int ALGO_COMBINATIONS = 5000;
+            int ALGO_COMBINATIONS = 10000;
             customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
             // for gpt, computeType & scaleType should be FP32