Unverified Commit 1f88baa5 authored by q.yao, committed by GitHub

update log info (#131)

* update log info

* format cuda utils
parent db3b986b
@@ -647,11 +647,11 @@ void cublasMMWrapper::SpGemm(cublasOperation_t transa,
                              void* C)
 {
     if (Atype_ != CUDA_R_16F || Btype_ != CUDA_R_16F || Ctype_ != CUDA_R_16F) {
-        throw std::runtime_error("\n[FT][ERROR] sparse GEMM only supports FP16 data type now.");
+        throw std::runtime_error("\n[TM][ERROR] sparse GEMM only supports FP16 data type now.");
     }
     static bool not_printed_fp32_accumulation_warning = true;
     if (computeType_ != CUDA_R_16F && not_printed_fp32_accumulation_warning) {
-        printf("[FT][WARNING] cublasMMWrapper sets to FP32 compute type, "
+        printf("[TM][WARNING] cublasMMWrapper sets to FP32 compute type, "
                "but sparse gemm will use FP16 compute type since cusparselt "
                "supports FP16 accumulation only.\n");
         not_printed_fp32_accumulation_warning = false;
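Note: the hunk above relies on a function-local static flag so the FP16-accumulation warning is printed at most once per process. A minimal, self-contained sketch of that warn-once idiom (function name is illustrative, not part of the wrapper):

```cpp
#include <cstdio>

// Warn-once idiom: a function-local static persists across calls, so the
// message is emitted only the first time the condition is hit.
void warn_fp16_accumulation_once()
{
    static bool not_printed = true;  // same pattern as the diff above
    if (not_printed) {
        std::printf("[TM][WARNING] sparse gemm will use FP16 compute type since "
                    "cusparselt supports FP16 accumulation only.\n");
        not_printed = false;
    }
}
```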
@@ -803,13 +803,13 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
     FT_CHECK_WITH_INFO(false, "CUBLAS version too low.");
     return {false, cublasLtMatmulAlgo_t{}};
 #else
     size_t  returnSize;
     int32_t pointer_mode;
     cublasLtMatmulDescGetAttribute(
         computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode), &returnSize);

     std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
     cublasLtMatmulPreference_t                   preference;
     check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
     check_cuda_error(cublasLtMatmulPreferenceInit(preference));
     uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
@@ -821,8 +821,8 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
         preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
 #endif
     int  return_count = 0;
     auto ret          = cublasLtMatmulAlgoGetHeuristic(lightHandle,
                                                        computeDesc,
                                                        Adesc,
                                                        Bdesc,
@@ -837,7 +837,7 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
     std::map<int, std::vector<float>> algo_results;
     for (const auto& heuristic : heuristics) {
         cublasLtMatmulAlgo_t algo = heuristic.algo;
         int32_t              algo_id;
         cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
         cudaEvent_t start_event, stop_event;
@@ -845,7 +845,7 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
         cudaEventCreate(&stop_event);
         float my_alpha = 1.0f;
         float my_beta  = 0.0f;
         for (int i = 0; i < 11; i++) {
             float duration_ms;
@@ -876,16 +876,16 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
     }
     cublasLtMatmulHeuristicResult_t result;
     float                           best_time = INFINITY;
     for (const auto& heuristic : heuristics) {
         cublasLtMatmulAlgo_t algo = heuristic.algo;
         int32_t              algo_id;
         cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
         const auto& results = algo_results[algo_id];
         if (results.size() > 0 && results[5] < best_time) {
             best_time = results[5];
             result    = heuristic;
         }
     }
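The selection above keeps `results[5]` as each algorithm's representative time out of the 11 timed runs; assuming the per-algorithm samples are sorted, that is the median, which is robust to warm-up effects and timing jitter. A self-contained sketch of median-of-N timing with CUDA events (helper name and signature are illustrative, not the wrapper's API):

```cpp
#include <algorithm>
#include <vector>
#include <cuda_runtime.h>

// Time a candidate launch n_runs times and return the median duration in ms.
template<typename LaunchFn>
float median_time_ms(LaunchFn launch, cudaStream_t stream, int n_runs = 11)
{
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    std::vector<float> samples(n_runs);
    for (int i = 0; i < n_runs; ++i) {
        cudaEventRecord(start, stream);
        launch();  // e.g. one cublasLtMatmul call with this algo
        cudaEventRecord(stop, stream);
        cudaEventSynchronize(stop);
        cudaEventElapsedTime(&samples[i], start, stop);
    }
    cudaEventDestroy(start);
    cudaEventDestroy(stop);

    std::sort(samples.begin(), samples.end());
    return samples[n_runs / 2];  // index 5 when n_runs == 11
}
```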
@@ -989,20 +989,20 @@ void cublasMMWrapper::_Int8Gemm(const int m,
 #else
     mu_->lock();
     const auto op_a        = CUBLAS_OP_T;
     const auto op_b        = CUBLAS_OP_N;
     const auto dataType    = CUDA_R_8I;
     const auto resultType  = mode == 0 ? CUDA_R_8I : CUDA_R_32I;
     const auto computeType = CUBLAS_COMPUTE_32I;
     const auto scaleType   = mode == 0 ? CUDA_R_32F : CUDA_R_32I;
     const int   batch_count = 1;
     const void* beta;

     int                     findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(dataType));
     cublasLtMatmulAlgo_info info     = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(dataType));

     cublasLtMatmulDesc_t   operationDesc = NULL;
     cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
     // --------------------------------------
@@ -1027,20 +1027,20 @@ void cublasMMWrapper::_Int8Gemm(const int m,
     check_cuda_error(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));

     const int32_t int_one    = 1;
     const int32_t int_zero   = 0;
     const float   float_zero = 0;
     if (mode == 0) {
         beta = per_column_scaling ? &float_zero : NULL;
     }
     else {
         alpha = &int_one;
         beta  = &int_zero;
     }

     cublasLtMatmulAlgo_t algo;
     void*                workSpace     = cublas_workspace_;
     int                  workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;

     sync_check_cuda_error();
     auto ret = cublasLtMatmulWrapper(cublaslt_handle_,
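For context, `mode` drives the output and scale types in the hunk above: INT8 inputs always accumulate in INT32 (`CUBLAS_COMPUTE_32I`), and mode 0 re-quantizes the accumulator back to INT8 with a float alpha, while other modes return the raw INT32 result. A hedged summary of that selection (struct and function names are illustrative):

```cpp
#include <cublasLt.h>

struct Int8GemmTypes {
    cudaDataType_t      result;
    cudaDataType_t      scale;
    cublasComputeType_t compute;
};

// Mirrors the type selection in _Int8Gemm above.
inline Int8GemmTypes pick_int8_gemm_types(int mode)
{
    Int8GemmTypes t;
    t.compute = CUBLAS_COMPUTE_32I;                     // INT8 GEMM accumulates in INT32
    t.result  = (mode == 0) ? CUDA_R_8I : CUDA_R_32I;   // quantized vs. raw accumulator
    t.scale   = (mode == 0) ? CUDA_R_32F : CUDA_R_32I;  // float alpha re-quantizes the output
    return t;
}
```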
......
@@ -38,7 +38,7 @@ void print_to_file(const T* result, const int size, const char* file, cudaStream
         delete[] tmp;
     }
     else {
-        throw std::runtime_error(std::string("[FT][ERROR] Cannot open file: ") + file + "\n");
+        throw std::runtime_error(std::string("[TM][ERROR] Cannot open file: ") + file + "\n");
     }
     cudaDeviceSynchronize();
     check_cuda_error(cudaGetLastError());
@@ -81,7 +81,7 @@ void print_abs_mean(const T* buf, uint size, cudaStream_t stream, std::string na
         }
         max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i]));
     }
-    printf("[INFO][FT] %20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s",
+    printf("[TM][INFO] %20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s",
            name.c_str(),
            size,
            sum / size,
......
@@ -119,7 +119,7 @@ template<typename T>
 void check(T result, char const* const func, const char* const file, int const line)
 {
     if (result) {
-        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
+        throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
                                  + file + ":" + std::to_string(line) + " \n");
     }
 }
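The `check` template above is normally reached through a macro that stringizes the call and captures `__FILE__`/`__LINE__`; the project defines an equivalent `check_cuda_error` wrapper elsewhere in this header, so the macro below is only an illustrative sketch:

```cpp
#include <cuda_runtime.h>

// Illustrative wrapper: a failing CUDA call throws with its call site.
#define CHECK_CUDA(val) check((val), #val, __FILE__, __LINE__)

void example(float** dev_ptr)
{
    // Throws "[TM][ERROR] CUDA runtime error: ..." with file:line on failure.
    CHECK_CUDA(cudaMalloc(dev_ptr, 1024 * sizeof(float)));
}
```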
@@ -137,7 +137,7 @@ inline void syncAndCheck(const char* const file, int const line)
     cudaDeviceSynchronize();
     cudaError_t result = cudaGetLastError();
     if (result) {
-        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result))
+        throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result))
                                  + " " + file + ":" + std::to_string(line) + " \n");
     }
     TM_LOG_DEBUG(fmtstr("run syncAndCheck at %s:%d", file, line));
@@ -148,7 +148,7 @@ inline void syncAndCheck(const char* const file, int const line)
     cudaDeviceSynchronize();
     cudaError_t result = cudaGetLastError();
     if (result) {
-        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
+        throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
                                  + file + ":" + std::to_string(line) + " \n");
     }
 #endif
@@ -194,12 +194,12 @@ void check_abs_mean_val(const T* result, const int size);
 #define PRINT_FUNC_NAME_() \
     do { \
-        std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \
+        std::cout << "[TM][CALL] " << __FUNCTION__ << " " << std::endl; \
     } while (0)

 [[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
 {
-    throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":"
+    throw std::runtime_error(std::string("[TM][ERROR] ") + info + " Assertion fail: " + file + ":"
                              + std::to_string(line) + " \n");
 }
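`throwRuntimeError` above is the `[[noreturn]]` backend for the header's assertion helpers (such as `myAssert` in the next hunk); a minimal sketch of an assertion macro built on it (the macro name here is illustrative, the project's own macros are equivalent):

```cpp
// Illustrative assertion macro: evaluates the condition once and throws
// through throwRuntimeError with the call site on failure.
#define TM_ASSERT(cond, info)                              \
    do {                                                   \
        if (!(cond)) {                                     \
            throwRuntimeError(__FILE__, __LINE__, (info)); \
        }                                                  \
    } while (0)
```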
@@ -226,7 +226,7 @@ inline void myAssert(bool result, const char* const file, int const line, std::s
     { \
         cusparseStatus_t status = (func); \
         if (status != CUSPARSE_STATUS_SUCCESS) { \
-            throw std::runtime_error(std::string("[FT][ERROR] CUSPARSE API failed at line ") \
+            throw std::runtime_error(std::string("[TM][ERROR] CUSPARSE API failed at line ") \
                                      + std::to_string(__LINE__) + " in file " + __FILE__ + ": " \
                                      + cusparseGetErrorString(status) + " " + std::to_string(status)); \
         } \
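For reference, this macro (`CHECK_CUSPARSE` in the header) wraps any call returning a `cusparseStatus_t`; a hypothetical usage sketch:

```cpp
#include <cusparse.h>

void sparse_handle_smoke_test()
{
    cusparseHandle_t handle;
    CHECK_CUSPARSE(cusparseCreate(&handle));  // throws "[TM][ERROR] CUSPARSE API failed ..." on error
    CHECK_CUSPARSE(cusparseDestroy(&handle));
}
```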
......
@@ -83,7 +83,7 @@ int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE
 #if (CUDART_VERSION >= 11000)
     cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
 #else
     stages = 0;
 #endif
     printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
@@ -149,7 +149,7 @@ int printBatchPerfStructure(
 #if (CUDART_VERSION >= 11000)
     cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
 #else
     stages = 0;
 #endif
     printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
@@ -352,7 +352,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
         order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
     }
 #else
     order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
 #endif
     int ldaTransform = 32 * m;
@@ -369,7 +369,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
 #if (CUDART_VERSION >= 11000)
     status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
 #else
     status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
 #endif
     if (status != CUBLAS_STATUS_SUCCESS) {
         goto CLEANUP;
@@ -689,7 +689,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
         order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
     }
 #else
     order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
 #endif
     int ldaTransform = 32 * m;
@@ -711,7 +711,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
 #if (CUDART_VERSION >= 11000)
     status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
 #else
     status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
 #endif
     if (status != CUBLAS_STATUS_SUCCESS) {
         goto CLEANUP;
@@ -1252,7 +1252,7 @@ int generate_encoder_igemm_config(
     cudaDeviceSynchronize();
     cudaError_t result = cudaGetLastError();
     if (result) {
-        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: "));
+        throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: "));
     }
     float exec_time = 99999.0f;
......
@@ -47,7 +47,7 @@ Logger::Logger()
     }
     else {
         fprintf(stderr,
-                "[FT][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
+                "[TM][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
                 "Ignore the environment variable and use a default "
                 "logging level.\n",
                 level_name);
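The Logger above reads the log level from the `TM_LOG_LEVEL` environment variable and falls back to a default on unrecognized values rather than aborting. A standalone sketch of that lookup (the level names are assumed to be the usual DEBUG/INFO/WARNING/ERROR set, not taken from the source):

```cpp
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Map a level name to an index; warn and fall back on unknown input.
int parse_log_level(const char* level_name, int default_level)
{
    static const char* kNames[] = {"DEBUG", "INFO", "WARNING", "ERROR"};
    if (level_name != nullptr) {
        for (int i = 0; i < 4; ++i) {
            if (std::strcmp(level_name, kNames[i]) == 0) {
                return i;
            }
        }
        std::fprintf(stderr,
                     "[TM][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
                     "Ignore the environment variable and use a default logging level.\n",
                     level_name);
    }
    return default_level;
}

// Usage: int level = parse_log_level(std::getenv("TM_LOG_LEVEL"), /*INFO*/ 1);
```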
......
@@ -14,13 +14,13 @@
  * limitations under the License.
  */

-#include "tests/unittests/gtest_utils.h"
 #include "src/turbomind/kernels/gen_relative_pos_bias.h"
 #include "src/turbomind/kernels/gpt_kernels.h"
 #include "src/turbomind/kernels/unfused_attention_kernels.h"
+#include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/memory_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
-#include "src/turbomind/utils/Tensor.h"
+#include "tests/unittests/gtest_utils.h"
 #include <curand.h>
 #include <sstream>
@@ -42,17 +42,18 @@ struct AttentionKernelTestParam {
     size_t rotary_embedding_dim = 0;
     bool   neox_rotary_style    = false;
     float  q_scaling            = 1.0f;
 };

 namespace utils {

-#define CHECK_CURAND(cmd) do { \
-    curandStatus_t err = cmd; \
-    if (err != CURAND_STATUS_SUCCESS) { \
-        throw std::runtime_error( \
-            fmtstr("[FT][ERROR] curand runtime error: %d", err)); \
-    }} while(0) \
+#define CHECK_CURAND(cmd) \
+    do { \
+        curandStatus_t err = cmd; \
+        if (err != CURAND_STATUS_SUCCESS) { \
+            throw std::runtime_error(fmtstr("[TM][ERROR] curand runtime error: %d", err)); \
+        } \
+    } while (0)

 __global__ void convert_and_copy(half* dst, const float* src, const size_t size)
 {
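A hypothetical usage of the reformatted `CHECK_CURAND` macro: fill a device buffer with uniform random floats, throwing on any cuRAND failure:

```cpp
#include <curand.h>

void fill_uniform(float* d_buf, size_t n)
{
    curandGenerator_t gen;
    CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
    CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, 1234ULL));
    CHECK_CURAND(curandGenerateUniform(gen, d_buf, n));  // d_buf must be device memory
    CHECK_CURAND(curandDestroyGenerator(gen));
}
```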
@@ -177,7 +178,7 @@ void computeQkSoftmax(T* attn_score,
         for (size_t ki = 0; ki < k_length; ++ki) {
             size_t qk_idx = qi * k_length + ki;
             if (int(mask[qk_idx]) > 0) {  // mask = 0 or 1.
                 float val          = (float)safe_add_bias(::math::mul(qk_scale, qk[qk_idx]), head_pos_bias, qk_idx);
                 attn_score[qk_idx] = static_cast<T>(expf(val - maxval) / (sum + EPSILON));
             }
             else {
@@ -188,12 +189,12 @@ void computeQkSoftmax(T* attn_score,
         // Move the data pointers to the next.
         attn_score += q_length * k_length;
         qk += q_length * k_length;
     }
 }

 template<typename T>
-class AttentionKernelTest : public FtTestBase {
+class AttentionKernelTest: public FtTestBase {
 private:
     using FtTestBase::stream;
@@ -252,10 +253,11 @@ public:
         FtTestBase::TearDown();
     }

-    void runTestMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false) {
+    void runTestMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false)
+    {
         DataType dtype = getTensorType<T>();

-        std::vector<size_t> qk_shape {param.batch_size, param.head_num, param.q_length, param.k_length};
+        std::vector<size_t> qk_shape{param.batch_size, param.head_num, param.q_length, param.k_length};
         bool use_fp32_qk = param.use_fp32_qk_buf && dtype != TYPE_FP32;
@@ -279,27 +281,27 @@ public:
         if (param.use_fp32_qk_buf && dtype != TYPE_FP32) {
             MaskedSoftmaxParam<T, float> softmax_param;
             softmax_param.attention_score = qk.getPtr<T>();
             softmax_param.qk              = qk_fp32.getPtr<float>();
             softmax_param.attention_mask  = attn_mask.getPtr<T>();
             softmax_param.batch_size      = param.batch_size;
             softmax_param.num_heads       = param.head_num;
             softmax_param.q_length        = param.q_length;
             softmax_param.k_length        = param.k_length;
             softmax_param.qk_scale        = scale;
             invokeMaskedSoftmax(softmax_param, stream);
             sync_check_cuda_error();
         }
         else {
             MaskedSoftmaxParam<T, T> softmax_param;
             softmax_param.attention_score = qk.getPtr<T>();
             softmax_param.qk              = qk.getPtr<T>();
             softmax_param.attention_mask  = attn_mask.getPtr<T>();
             softmax_param.batch_size      = param.batch_size;
             softmax_param.num_heads       = param.head_num;
             softmax_param.q_length        = param.q_length;
             softmax_param.k_length        = param.k_length;
             softmax_param.qk_scale        = scale;
             invokeMaskedSoftmax(softmax_param, stream);
             sync_check_cuda_error();
         }
@@ -332,10 +334,11 @@ public:
         }
     }

-    void runTestAlibiMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false) {
+    void runTestAlibiMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false)
+    {
         DataType dtype = getTensorType<T>();

-        std::vector<size_t> qk_shape {param.batch_size, param.head_num, param.q_length, param.k_length};
+        std::vector<size_t> qk_shape{param.batch_size, param.head_num, param.q_length, param.k_length};
         bool use_fp32_qk = param.use_fp32_qk_buf && dtype != TYPE_FP32;
@@ -355,16 +358,17 @@ public:
         sync_check_cuda_error();

         Tensor h_alibi_slopes = createTensor(MEMORY_CPU, dtype, {param.head_num});
-        Tensor h_alibi_bias   = is_benchmark ? Tensor() :
-            createTensor(MEMORY_CPU, dtype, {param.head_num, param.q_length, param.k_length});
+        Tensor h_alibi_bias =
+            is_benchmark ? Tensor() : createTensor(MEMORY_CPU, dtype, {param.head_num, param.q_length, param.k_length});
         // The nearest power of 2 equal to / smaller than num_heads followed by HF's implementation.
         T*  alibi_slope_ptr = h_alibi_slopes.getPtr<T>();
         int num_heads_pow2  = utils::pow2_rounddown(param.head_num);
         for (size_t h = 0; h < param.head_num; ++h) {
             // The slope of linear bias of the attention head
             if (h < num_heads_pow2) {
                 alibi_slope_ptr[h] = static_cast<T>(powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2) - 3.f)), h + 1));
-            } else {
+            }
+            else {
                 alibi_slope_ptr[h] = static_cast<T>(
                     powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2 << 1) - 3.f)), (h - num_heads_pow2) * 2 + 1));
             }
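The nested `powf` calls above implement the standard ALiBi slope schedule: with n the largest power of two not exceeding the head count, head h (0-based) among the first n heads gets slope 2^(-8(h+1)/n), and any remaining heads interleave the doubled-n schedule, 2^(-4(2(h-n)+1)/n). A plain-float sketch of that closed form (function name illustrative):

```cpp
#include <cmath>
#include <vector>

std::vector<float> alibi_slopes(int num_heads)
{
    int n = 1;
    while (n * 2 <= num_heads) {
        n *= 2;  // nearest power of two <= num_heads
    }
    std::vector<float> slopes(num_heads);
    for (int h = 0; h < num_heads; ++h) {
        slopes[h] = (h < n) ? std::exp2f(-8.0f * (h + 1) / n)
                            : std::exp2f(-4.0f * (2 * (h - n) + 1) / n);
    }
    return slopes;
}
```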
@@ -372,7 +376,7 @@ public:
             T* alibi_bias_ptr = h_alibi_bias.getPtr<T>();
             for (size_t qi = 0; qi < param.q_length; ++qi) {
                 for (size_t ki = 0; ki < param.k_length; ++ki) {
                     size_t hqk_idx          = (h * param.q_length + qi) * param.k_length + ki;
                     alibi_bias_ptr[hqk_idx] = ::math::mul(alibi_slope_ptr[h], T(0.0f + ki - qi));
                 }
             }
@@ -448,87 +452,106 @@ public:
 TYPED_TEST_SUITE(AttentionKernelTest, SupportTypes);

-TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt) {
+TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt)
+{
     this->runTestMaskedSoftmax({1, 12, 12, 1, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt2) {
+TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt2)
+{
     // q_length is not multiple of 4.
     this->runTestMaskedSoftmax({1, 11, 11, 4, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt) {
+TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt)
+{
     this->runTestMaskedSoftmax({1, 12, 24, 2, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt2) {
+TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt2)
+{
     this->runTestMaskedSoftmax({1, 11, 24, 2, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence1024) {
+TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence1024)
+{
     this->runTestMaskedSoftmax({1, 12, 1024, 2, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence2048) {
+TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence2048)
+{
     this->runTestMaskedSoftmax({1, 12, 2048, 2, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence3072) {
+TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence3072)
+{
     this->runTestMaskedSoftmax({1, 12, 3072, 2, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence4096) {
+TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence4096)
+{
     this->runTestMaskedSoftmax({1, 12, 4096, 2, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence1024) {
+TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence1024)
+{
     // Assume the bloom 176B model with 8 TP.
     this->runTestMaskedSoftmax({8, 1024, 1024, 14, 128, false, 0, false, true}, true);
 }

-TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence2048) {
+TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence2048)
+{
     // Assume the bloom 176B model with 8 TP.
     this->runTestMaskedSoftmax({8, 2048, 2048, 14, 128, false, 0, false, true}, true);
 }

-TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence4096) {
+TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence4096)
+{
     // Assume the bloom 176B model with 8 TP.
     this->runTestMaskedSoftmax({8, 4096, 4096, 14, 128, false, 0, false, true}, true);
 }

-TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence1) {
+TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence1)
+{
     this->runTestAlibiMaskedSoftmax({1, 12, 12, 4, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence2) {
+TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence2)
+{
     // q_length is not multiple of 4.
     this->runTestAlibiMaskedSoftmax({1, 11, 11, 4, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt1) {
+TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt1)
+{
     this->runTestAlibiMaskedSoftmax({1, 12, 20, 4, 32, false, 0, false});
 }

-TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt2) {
+TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt2)
+{
     // q_length is not multiple of 4.
     this->runTestAlibiMaskedSoftmax({1, 11, 20, 4, 32, false, 0, false});
 }

 // Tests for long sentence generation. Assume the bloom 176B model with 8 TP.
-TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence1024) {
+TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence1024)
+{
     this->runTestAlibiMaskedSoftmax({8, 1024, 1024, 14, 128, false, 0, false, true}, true);
 }

-TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence2048) {
+TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence2048)
+{
     this->runTestAlibiMaskedSoftmax({8, 2048, 2048, 14, 128, false, 0, false, true}, true);
 }

-TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence3072) {
+TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence3072)
+{
     this->runTestAlibiMaskedSoftmax({8, 3072, 3072, 14, 128, false, 0, false, true}, true);
 }

-TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence4096) {
+TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence4096)
+{
     this->runTestAlibiMaskedSoftmax({4, 4096, 4096, 14, 128, false, 0, false, true}, true);
 }
......