Unverified Commit 1f88baa5 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

update log info (#131)

* update log info

* format cuda utils
parent db3b986b
......@@ -647,11 +647,11 @@ void cublasMMWrapper::SpGemm(cublasOperation_t transa,
void* C)
{
if (Atype_ != CUDA_R_16F || Btype_ != CUDA_R_16F || Ctype_ != CUDA_R_16F) {
throw std::runtime_error("\n[FT][ERROR] sparse GEMM only supports FP16 data type now.");
throw std::runtime_error("\n[TM][ERROR] sparse GEMM only supports FP16 data type now.");
}
static bool not_printed_fp32_accumulation_warning = true;
if (computeType_ != CUDA_R_16F && not_printed_fp32_accumulation_warning) {
printf("[FT][WARNING] cublasMMWrapper sets to FP32 compute type, "
printf("[TM][WARNING] cublasMMWrapper sets to FP32 compute type, "
"but sparse gemm will use FP16 compute type since cusparselt "
"supports FP16 accumulation only.\n");
not_printed_fp32_accumulation_warning = false;
......@@ -803,13 +803,13 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
FT_CHECK_WITH_INFO(false, "CUBLAS version too low.");
return {false, cublasLtMatmulAlgo_t{}};
#else
size_t returnSize;
size_t returnSize;
int32_t pointer_mode;
cublasLtMatmulDescGetAttribute(
computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode), &returnSize);
std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
cublasLtMatmulPreference_t preference;
cublasLtMatmulPreference_t preference;
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
......@@ -821,8 +821,8 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
#endif
int return_count = 0;
auto ret = cublasLtMatmulAlgoGetHeuristic(lightHandle,
int return_count = 0;
auto ret = cublasLtMatmulAlgoGetHeuristic(lightHandle,
computeDesc,
Adesc,
Bdesc,
......@@ -837,7 +837,7 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
std::map<int, std::vector<float>> algo_results;
for (const auto& heuristic : heuristics) {
cublasLtMatmulAlgo_t algo = heuristic.algo;
int32_t algo_id;
int32_t algo_id;
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
cudaEvent_t start_event, stop_event;
......@@ -845,7 +845,7 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
cudaEventCreate(&stop_event);
float my_alpha = 1.0f;
float my_beta = 0.0f;
float my_beta = 0.0f;
for (int i = 0; i < 11; i++) {
float duration_ms;
......@@ -876,16 +876,16 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
}
cublasLtMatmulHeuristicResult_t result;
float best_time = INFINITY;
float best_time = INFINITY;
for (const auto& heuristic : heuristics) {
cublasLtMatmulAlgo_t algo = heuristic.algo;
int32_t algo_id;
int32_t algo_id;
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
const auto& results = algo_results[algo_id];
if (results.size() > 0 && results[5] < best_time) {
best_time = results[5];
result = heuristic;
result = heuristic;
}
}
......@@ -989,20 +989,20 @@ void cublasMMWrapper::_Int8Gemm(const int m,
#else
mu_->lock();
const auto op_a = CUBLAS_OP_T;
const auto op_b = CUBLAS_OP_N;
const auto dataType = CUDA_R_8I;
const auto resultType = mode == 0 ? CUDA_R_8I : CUDA_R_32I;
const auto computeType = CUBLAS_COMPUTE_32I;
const auto scaleType = mode == 0 ? CUDA_R_32F : CUDA_R_32I;
const int batch_count = 1;
const auto op_a = CUBLAS_OP_T;
const auto op_b = CUBLAS_OP_N;
const auto dataType = CUDA_R_8I;
const auto resultType = mode == 0 ? CUDA_R_8I : CUDA_R_32I;
const auto computeType = CUBLAS_COMPUTE_32I;
const auto scaleType = mode == 0 ? CUDA_R_32F : CUDA_R_32I;
const int batch_count = 1;
const void* beta;
int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(dataType));
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(dataType));
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
// --------------------------------------
......@@ -1027,20 +1027,20 @@ void cublasMMWrapper::_Int8Gemm(const int m,
check_cuda_error(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
const int32_t int_one = 1;
const int32_t int_zero = 0;
const float float_zero = 0;
const int32_t int_one = 1;
const int32_t int_zero = 0;
const float float_zero = 0;
if (mode == 0) {
beta = per_column_scaling ? &float_zero : NULL;
}
else {
alpha = &int_one;
beta = &int_zero;
beta = &int_zero;
}
cublasLtMatmulAlgo_t algo;
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
sync_check_cuda_error();
auto ret = cublasLtMatmulWrapper(cublaslt_handle_,
......
......@@ -38,7 +38,7 @@ void print_to_file(const T* result, const int size, const char* file, cudaStream
delete[] tmp;
}
else {
throw std::runtime_error(std::string("[FT][ERROR] Cannot open file: ") + file + "\n");
throw std::runtime_error(std::string("[TM][ERROR] Cannot open file: ") + file + "\n");
}
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
......@@ -81,7 +81,7 @@ void print_abs_mean(const T* buf, uint size, cudaStream_t stream, std::string na
}
max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i]));
}
printf("[INFO][FT] %20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s",
printf("[TM][INFO] %20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s",
name.c_str(),
size,
sum / size,
......
......@@ -119,7 +119,7 @@ template<typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
if (result) {
throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
+ file + ":" + std::to_string(line) + " \n");
}
}
......@@ -137,7 +137,7 @@ inline void syncAndCheck(const char* const file, int const line)
cudaDeviceSynchronize();
cudaError_t result = cudaGetLastError();
if (result) {
throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result))
throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result))
+ " " + file + ":" + std::to_string(line) + " \n");
}
TM_LOG_DEBUG(fmtstr("run syncAndCheck at %s:%d", file, line));
......@@ -148,7 +148,7 @@ inline void syncAndCheck(const char* const file, int const line)
cudaDeviceSynchronize();
cudaError_t result = cudaGetLastError();
if (result) {
throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
+ file + ":" + std::to_string(line) + " \n");
}
#endif
......@@ -194,12 +194,12 @@ void check_abs_mean_val(const T* result, const int size);
#define PRINT_FUNC_NAME_() \
do { \
std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \
std::cout << "[TM][CALL] " << __FUNCTION__ << " " << std::endl; \
} while (0)
[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
{
throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":"
throw std::runtime_error(std::string("[TM][ERROR] ") + info + " Assertion fail: " + file + ":"
+ std::to_string(line) + " \n");
}
......@@ -226,7 +226,7 @@ inline void myAssert(bool result, const char* const file, int const line, std::s
{ \
cusparseStatus_t status = (func); \
if (status != CUSPARSE_STATUS_SUCCESS) { \
throw std::runtime_error(std::string("[FT][ERROR] CUSPARSE API failed at line ") \
throw std::runtime_error(std::string("[TM][ERROR] CUSPARSE API failed at line ") \
+ std::to_string(__LINE__) + " in file " + __FILE__ + ": " \
+ cusparseGetErrorString(status) + " " + std::to_string(status)); \
} \
......
......@@ -83,7 +83,7 @@ int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else
stages = 0;
stages = 0;
#endif
printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
......@@ -149,7 +149,7 @@ int printBatchPerfStructure(
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else
stages = 0;
stages = 0;
#endif
printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
......@@ -352,7 +352,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
}
#else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif
int ldaTransform = 32 * m;
......@@ -369,7 +369,7 @@ int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
#if (CUDART_VERSION >= 11000)
status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else
status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
#endif
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
......@@ -689,7 +689,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
}
#else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif
int ldaTransform = 32 * m;
......@@ -711,7 +711,7 @@ int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
#if (CUDART_VERSION >= 11000)
status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else
status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
#endif
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
......@@ -1252,7 +1252,7 @@ int generate_encoder_igemm_config(
cudaDeviceSynchronize();
cudaError_t result = cudaGetLastError();
if (result) {
throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: "));
throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: "));
}
float exec_time = 99999.0f;
......
......@@ -47,7 +47,7 @@ Logger::Logger()
}
else {
fprintf(stderr,
"[FT][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
"[TM][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
"Ignore the environment variable and use a default "
"logging level.\n",
level_name);
......
......@@ -14,13 +14,13 @@
* limitations under the License.
*/
#include "tests/unittests/gtest_utils.h"
#include "src/turbomind/kernels/gen_relative_pos_bias.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/memory_utils.h"
#include "src/turbomind/utils/nccl_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "tests/unittests/gtest_utils.h"
#include <curand.h>
#include <sstream>
......@@ -42,17 +42,18 @@ struct AttentionKernelTestParam {
size_t rotary_embedding_dim = 0;
bool neox_rotary_style = false;
float q_scaling = 1.0f;
float q_scaling = 1.0f;
};
namespace utils {
#define CHECK_CURAND(cmd) do { \
curandStatus_t err = cmd; \
if (err != CURAND_STATUS_SUCCESS) { \
throw std::runtime_error( \
fmtstr("[FT][ERROR] curand runtime error: %d", err)); \
}} while(0) \
#define CHECK_CURAND(cmd) \
do { \
curandStatus_t err = cmd; \
if (err != CURAND_STATUS_SUCCESS) { \
throw std::runtime_error(fmtstr("[TM][ERROR] curand runtime error: %d", err)); \
} \
} while (0)
__global__ void convert_and_copy(half* dst, const float* src, const size_t size)
{
......@@ -177,7 +178,7 @@ void computeQkSoftmax(T* attn_score,
for (size_t ki = 0; ki < k_length; ++ki) {
size_t qk_idx = qi * k_length + ki;
if (int(mask[qk_idx]) > 0) { // mask = 0 or 1.
float val = (float)safe_add_bias(::math::mul(qk_scale, qk[qk_idx]), head_pos_bias, qk_idx);
float val = (float)safe_add_bias(::math::mul(qk_scale, qk[qk_idx]), head_pos_bias, qk_idx);
attn_score[qk_idx] = static_cast<T>(expf(val - maxval) / (sum + EPSILON));
}
else {
......@@ -188,12 +189,12 @@ void computeQkSoftmax(T* attn_score,
// Move the data pointers to the next.
attn_score += q_length * k_length;
qk += q_length * k_length;
qk += q_length * k_length;
}
}
template<typename T>
class AttentionKernelTest : public FtTestBase {
class AttentionKernelTest: public FtTestBase {
private:
using FtTestBase::stream;
......@@ -252,10 +253,11 @@ public:
FtTestBase::TearDown();
}
void runTestMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false) {
void runTestMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false)
{
DataType dtype = getTensorType<T>();
std::vector<size_t> qk_shape {param.batch_size, param.head_num, param.q_length, param.k_length};
std::vector<size_t> qk_shape{param.batch_size, param.head_num, param.q_length, param.k_length};
bool use_fp32_qk = param.use_fp32_qk_buf && dtype != TYPE_FP32;
......@@ -279,27 +281,27 @@ public:
if (param.use_fp32_qk_buf && dtype != TYPE_FP32) {
MaskedSoftmaxParam<T, float> softmax_param;
softmax_param.attention_score = qk.getPtr<T>();
softmax_param.qk = qk_fp32.getPtr<float>();
softmax_param.attention_mask = attn_mask.getPtr<T>();
softmax_param.batch_size = param.batch_size;
softmax_param.num_heads = param.head_num;
softmax_param.q_length = param.q_length;
softmax_param.k_length = param.k_length;
softmax_param.qk_scale = scale;
softmax_param.attention_score = qk.getPtr<T>();
softmax_param.qk = qk_fp32.getPtr<float>();
softmax_param.attention_mask = attn_mask.getPtr<T>();
softmax_param.batch_size = param.batch_size;
softmax_param.num_heads = param.head_num;
softmax_param.q_length = param.q_length;
softmax_param.k_length = param.k_length;
softmax_param.qk_scale = scale;
invokeMaskedSoftmax(softmax_param, stream);
sync_check_cuda_error();
}
else {
MaskedSoftmaxParam<T, T> softmax_param;
softmax_param.attention_score = qk.getPtr<T>();
softmax_param.qk = qk.getPtr<T>();
softmax_param.attention_mask = attn_mask.getPtr<T>();
softmax_param.batch_size = param.batch_size;
softmax_param.num_heads = param.head_num;
softmax_param.q_length = param.q_length;
softmax_param.k_length = param.k_length;
softmax_param.qk_scale = scale;
softmax_param.attention_score = qk.getPtr<T>();
softmax_param.qk = qk.getPtr<T>();
softmax_param.attention_mask = attn_mask.getPtr<T>();
softmax_param.batch_size = param.batch_size;
softmax_param.num_heads = param.head_num;
softmax_param.q_length = param.q_length;
softmax_param.k_length = param.k_length;
softmax_param.qk_scale = scale;
invokeMaskedSoftmax(softmax_param, stream);
sync_check_cuda_error();
}
......@@ -332,10 +334,11 @@ public:
}
}
void runTestAlibiMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false) {
void runTestAlibiMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false)
{
DataType dtype = getTensorType<T>();
std::vector<size_t> qk_shape {param.batch_size, param.head_num, param.q_length, param.k_length};
std::vector<size_t> qk_shape{param.batch_size, param.head_num, param.q_length, param.k_length};
bool use_fp32_qk = param.use_fp32_qk_buf && dtype != TYPE_FP32;
......@@ -355,16 +358,17 @@ public:
sync_check_cuda_error();
Tensor h_alibi_slopes = createTensor(MEMORY_CPU, dtype, {param.head_num});
Tensor h_alibi_bias = is_benchmark ? Tensor() :
createTensor(MEMORY_CPU, dtype, {param.head_num, param.q_length, param.k_length});
Tensor h_alibi_bias =
is_benchmark ? Tensor() : createTensor(MEMORY_CPU, dtype, {param.head_num, param.q_length, param.k_length});
// The nearest power of 2 equal to / smaller than num_heads followed by HF's implementation.
T* alibi_slope_ptr = h_alibi_slopes.getPtr<T>();
int num_heads_pow2 = utils::pow2_rounddown(param.head_num);
T* alibi_slope_ptr = h_alibi_slopes.getPtr<T>();
int num_heads_pow2 = utils::pow2_rounddown(param.head_num);
for (size_t h = 0; h < param.head_num; ++h) {
// The slope of linear bias of the attention head
if (h < num_heads_pow2) {
alibi_slope_ptr[h] = static_cast<T>(powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2) - 3.f)), h + 1));
} else {
}
else {
alibi_slope_ptr[h] = static_cast<T>(
powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2 << 1) - 3.f)), (h - num_heads_pow2) * 2 + 1));
}
......@@ -372,7 +376,7 @@ public:
T* alibi_bias_ptr = h_alibi_bias.getPtr<T>();
for (size_t qi = 0; qi < param.q_length; ++qi) {
for (size_t ki = 0; ki < param.k_length; ++ki) {
size_t hqk_idx = (h * param.q_length + qi) * param.k_length + ki;
size_t hqk_idx = (h * param.q_length + qi) * param.k_length + ki;
alibi_bias_ptr[hqk_idx] = ::math::mul(alibi_slope_ptr[h], T(0.0f + ki - qi));
}
}
......@@ -448,87 +452,106 @@ public:
TYPED_TEST_SUITE(AttentionKernelTest, SupportTypes);
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt) {
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt)
{
this->runTestMaskedSoftmax({1, 12, 12, 1, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt2) {
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt2)
{
// q_length is not multiple of 4.
this->runTestMaskedSoftmax({1, 11, 11, 4, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt) {
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt)
{
this->runTestMaskedSoftmax({1, 12, 24, 2, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt2) {
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt2)
{
this->runTestMaskedSoftmax({1, 11, 24, 2, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence1024) {
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence1024)
{
this->runTestMaskedSoftmax({1, 12, 1024, 2, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence2048) {
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence2048)
{
this->runTestMaskedSoftmax({1, 12, 2048, 2, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence3072) {
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence3072)
{
this->runTestMaskedSoftmax({1, 12, 3072, 2, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence4096) {
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence4096)
{
this->runTestMaskedSoftmax({1, 12, 4096, 2, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence1024) {
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence1024)
{
// Assume the bloom 176B model with 8 TP.
this->runTestMaskedSoftmax({8, 1024, 1024, 14, 128, false, 0, false, true}, true);
}
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence2048) {
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence2048)
{
// Assume the bloom 176B model with 8 TP.
this->runTestMaskedSoftmax({8, 2048, 2048, 14, 128, false, 0, false, true}, true);
}
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence4096) {
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence4096)
{
// Assume the bloom 176B model with 8 TP.
this->runTestMaskedSoftmax({8, 4096, 4096, 14, 128, false, 0, false, true}, true);
}
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence1) {
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence1)
{
this->runTestAlibiMaskedSoftmax({1, 12, 12, 4, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence2) {
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence2)
{
// q_length is not multiple of 4.
this->runTestAlibiMaskedSoftmax({1, 11, 11, 4, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt1) {
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt1)
{
this->runTestAlibiMaskedSoftmax({1, 12, 20, 4, 32, false, 0, false});
}
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt2) {
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt2)
{
// q_length is not multiple of 4.
this->runTestAlibiMaskedSoftmax({1, 11, 20, 4, 32, false, 0, false});
}
// Tests for long sentence generation. Assume the bloom 176B model with 8 TP.
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence1024) {
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence1024)
{
this->runTestAlibiMaskedSoftmax({8, 1024, 1024, 14, 128, false, 0, false, true}, true);
}
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence2048) {
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence2048)
{
this->runTestAlibiMaskedSoftmax({8, 2048, 2048, 14, 128, false, 0, false, true}, true);
}
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence3072) {
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence3072)
{
this->runTestAlibiMaskedSoftmax({8, 3072, 3072, 14, 128, false, 0, false, true}, true);
}
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence4096) {
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence4096)
{
this->runTestAlibiMaskedSoftmax({4, 4096, 4096, 14, 128, false, 0, false, true}, true);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment