Unverified Commit e357c71f authored by Li Zhang, committed by GitHub

add gemm tuning (#18)

parent 93604c3f
@@ -37,4 +37,7 @@ target_link_libraries(Llama PUBLIC -lcudart
                      nccl_utils
                      cuda_utils
                      logger
                      llama_fmha)
\ No newline at end of file
add_executable(llama_gemm llama_gemm.cc)
target_link_libraries(llama_gemm PUBLIC -lcudart gpt_gemm_func memory_utils cuda_utils logger)
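Note: the two lines above add a standalone llama_gemm tuning executable linked against FasterTransformer's gpt_gemm_func. A minimal build sketch, assuming an already-configured CMake build directory (directory layout and generator are assumptions, not part of this commit):

    cmake --build . --target llama_gemm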
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Copied from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/gpt_gemm.cc
#include "src/fastertransformer/utils/gemm_test/gpt_gemm_func.h"
#include "src/fastertransformer/utils/memory_utils.h"
namespace ft = fastertransformer;
int main(int argc, char* argv[])
{
    if (argc < 9 || argc > 11) {
        FT_LOG_ERROR("./bin/llama_gemm batch_size \\ \n"
                     " beam_width \\ \n"
                     " max_input_len \\ \n"
                     " head_number \\ \n"
                     " size_per_head \\ \n"
                     " inter_size \\ \n"
                     " vocab_size \\ \n"
                     " data_type \\ \n"
                     " tensor_para_size \\\n"
                     " is_append (append the new config to an existing gemm_config.ini or not)");
        FT_LOG_ERROR("e.g. ./bin/llama_gemm 8 4 32 96 128 49152 51200 1 8 1");
        return 0;
    }
    const int batch_size = atoi(argv[1]);
    const int beam_width = atoi(argv[2]);
    const int max_input_len = atoi(argv[3]);
    const int head_num = atoi(argv[4]);
    const int size_per_head = atoi(argv[5]);
    const int inter_size = atoi(argv[6]);
    const int vocab_size = atoi(argv[7]);
    const ft::CublasDataType data_type = static_cast<ft::CublasDataType>(atoi(argv[8]));  // 0 FP32, 1 FP16, 2 BF16
    const int tensor_para_size = argc < 10 ? 1 : atoi(argv[9]);
    const bool is_append = argc < 11 ? false : (bool)(atoi(argv[10]));

    FT_LOG_INFO("Arguments:");
    FT_LOG_INFO(" batch_size: %d", batch_size);
    FT_LOG_INFO(" beam_width: %d", beam_width);
    FT_LOG_INFO(" max_input_len: %d", max_input_len);
    FT_LOG_INFO(" head_num: %d", head_num);
    FT_LOG_INFO(" size_per_head: %d", size_per_head);
    FT_LOG_INFO(" inter_size: %d", inter_size);
    FT_LOG_INFO(" vocab_size: %d", vocab_size);
    FT_LOG_INFO(" data_type: %d", data_type);
    FT_LOG_INFO(" tensor_para_size: %d", tensor_para_size);
    FT_LOG_INFO(" is_append: %d", (int)is_append);
    std::cout << std::endl;
    // Work out how much scratch memory the GEMM test needs and check that it fits on the device
    // (with ~10 MB of headroom) before allocating.
    void* gemm_test_buf;
    size_t buf_size_in_byte = ft::calGptGemmTestBufSizeInByte(batch_size,
                                                              beam_width,
                                                              max_input_len,
                                                              head_num,
                                                              size_per_head,
                                                              inter_size,
                                                              vocab_size,
                                                              tensor_para_size,
                                                              data_type);
    size_t total, free;
    ft::check_cuda_error(cudaMemGetInfo(&free, &total));
    if (free < buf_size_in_byte + 10 * 1024 * 1024) {
        printf("[ERROR] There is not enough device memory for the gemm test!\n"
               " %ld Bytes are needed, but only %ld Bytes are free.\n",
               buf_size_in_byte,
               free);
        gemm_test_buf = NULL;
        return -1;
    }
    else {
        ft::deviceMalloc(reinterpret_cast<char**>(&gemm_test_buf), buf_size_in_byte, false);
    }
    // Profile the GEMMs for the requested precision; results go to gemm_config.ini.
    if (data_type == ft::FLOAT_DATATYPE) {
        ft::generate_gpt_gemm_config<float>(batch_size,
                                            beam_width,
                                            max_input_len,
                                            head_num,
                                            size_per_head,
                                            inter_size,
                                            vocab_size,
                                            tensor_para_size,
                                            gemm_test_buf,
                                            is_append);
    }
    else if (data_type == ft::HALF_DATATYPE) {
        ft::generate_gpt_gemm_config<half>(batch_size,
                                           beam_width,
                                           max_input_len,
                                           head_num,
                                           size_per_head,
                                           inter_size,
                                           vocab_size,
                                           tensor_para_size,
                                           gemm_test_buf,
                                           is_append);
    }
#ifdef ENABLE_BF16
    else if (data_type == ft::BFLOAT16_DATATYPE) {
        ft::generate_gpt_gemm_config<__nv_bfloat16>(batch_size,
                                                    beam_width,
                                                    max_input_len,
                                                    head_num,
                                                    size_per_head,
                                                    inter_size,
                                                    vocab_size,
                                                    tensor_para_size,
                                                    gemm_test_buf,
                                                    is_append);
    }
#endif
#ifdef ENABLE_FP8
    else if (data_type == ft::FP8_DATATYPE) {
        ft::generate_gpt_gemm_config<__nv_fp8_e4m3>(batch_size,
                                                    beam_width,
                                                    max_input_len,
                                                    head_num,
                                                    size_per_head,
                                                    inter_size,
                                                    vocab_size,
                                                    tensor_para_size,
                                                    gemm_test_buf,
                                                    false);
    }
#endif
    else {
        printf("[ERROR] data type only supports fp32(0), fp16(1), bf16(2), fp8(4). \n");
        return -1;
    }

    ft::check_cuda_error(cudaFree(gemm_test_buf));
    return 0;
}
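A usage sketch, assuming the binary lands under ./bin of the build tree as the usage string above suggests:

    ./bin/llama_gemm 8 4 32 96 128 49152 51200 1 8 1

This mirrors the example in the usage message: batch_size 8, beam_width 4, max_input_len 32, head_num 96, size_per_head 128, inter_size 49152, vocab_size 51200, data_type 1 (FP16), tensor_para_size 8, and is_append 1 to append the result to an existing gemm_config.ini.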