// gemv_bf16 benchmark driver: times several BF16 GEMV kernel variants
// computing y = alpha * A^T * x + beta * y (row-major A, transA = T).
//
// NOTE(review): this file was recovered from a formatting-mangled source in
// which every `<...>` token (template-argument lists and `<<<...>>>` kernel
// launch configurations) had been stripped.  All template arguments and
// launch configs below are reconstructed from the kernel names and the
// constants defined here (NTL, UNROLL, ROWS_PER_WARP, TILE_K) — confirm the
// exact template parameter order against the declarations in gemv_bf16.h.
#include "gemv_bf16.h"

#include <cstdio>
#include <cstdlib>
#include <functional>
#include <string>
#include <vector>

// One benchmark entry: a display name plus a launcher that picks the grid
// size and fires the kernel with the benchmark's uniform GEMV signature.
// NOTE(review): the element type of the kernel registry was stripped by the
// mangling; if gemv_bf16.h already declares an equivalent entry type used by
// run_benchmark(), replace BenchKernel with that type.
struct BenchKernel {
  std::string name;
  std::function<void(int M, int K, float alpha, const hip_bfloat16 *A,
                     int lda, const hip_bfloat16 *x, float beta,
                     hip_bfloat16 *y)>
      launch;
};

int main(int argc, char **argv) {
  // Benchmark configuration defaults; all overridable from the command line.
  int warmups = 100;
  int loops = 2000;
  bool do_verify = false;
  float alpha = 1.0f;
  float beta = 0.0f;
  int M = 11264;
  // int N = 1; // Unused: GEMV, so x and y are vectors.
  int K = 4096;
  int lda = K;
  int block_size = 128;

  // Command-line overrides (getCmdOption comes from gemv_bf16.h).
  if (char *value = getCmdOption(argv, argv + argc, "--warmups")) {
    warmups = std::stoi(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "--loops")) {
    loops = std::stoi(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "--verify")) {
    do_verify = std::stoi(value) == 1;
  }
  if (char *value = getCmdOption(argv, argv + argc, "--alpha")) {
    alpha = std::stof(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "--beta")) {
    beta = std::stof(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "-M")) {
    M = std::stoi(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "-K")) {
    K = std::stoi(value);
    lda = K; // Keep lda consistent unless --lda explicitly overrides it.
  }
  if (char *value = getCmdOption(argv, argv + argc, "--lda")) {
    lda = std::stoi(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "-B")) {
    block_size = std::stoi(value);
  }

  // transA = T, so A is effectively row-major: M rows of lda elements.
  size_t count_A = (size_t)M * lda;
  size_t size_A = count_A * sizeof(hip_bfloat16);
  size_t size_x = (size_t)K * sizeof(hip_bfloat16);
  size_t size_y = (size_t)M * sizeof(hip_bfloat16);

  // Host allocations.
  std::vector<hip_bfloat16> h_A(count_A);
  std::vector<hip_bfloat16> h_x(K);
  std::vector<hip_bfloat16> h_y(M);

  // Random initialization in [0, 1].  (size_t indices: count_A can exceed
  // INT_MAX for large M * lda.)
  const float rand_max = static_cast<float>(RAND_MAX);
  for (size_t i = 0; i < count_A; i++)
    h_A[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
  for (size_t i = 0; i < (size_t)K; i++)
    h_x[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
  for (size_t i = 0; i < (size_t)M; i++)
    h_y[i] = hip_bfloat16(0.0f);

  // Device allocations and host->device copies.
  hip_bfloat16 *d_A, *d_x, *d_y;
  checkHipErrors(hipMalloc(&d_A, size_A));
  checkHipErrors(hipMalloc(&d_x, size_x));
  checkHipErrors(hipMalloc(&d_y, size_y));
  checkHipErrors(hipMemcpy(d_A, h_A.data(), size_A, hipMemcpyHostToDevice));
  checkHipErrors(hipMemcpy(d_x, h_x.data(), size_x, hipMemcpyHostToDevice));
  checkHipErrors(hipMemcpy(d_y, h_y.data(), size_y, hipMemcpyHostToDevice));

  // Kernel registry.  Tuning constants shared by the variants:
  constexpr bool NTL = true;         // non-temporal loads variant flag
  constexpr int UNROLL = 4;          // K-loop unroll factor
  constexpr int ROWS_PER_WARP = 2;   // rows handled by one warp (mr variants)
#if defined(__HIP_PLATFORM_AMD__)
  constexpr int TILE_K = calculate_tile_k<8>(8);
#else
  constexpr int TILE_K = calculate_tile_k<8>(1);
#endif
  std::vector<BenchKernel> kernels;

  // One thread per output row.
  kernels.push_back(
      {"naive",
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int grid = (M + block_size - 1) / block_size;
         gemv_bf16_TN_naive<<<grid, block_size>>>(M, K, alpha, A, lda, x,
                                                  beta, y);
       }});
  kernels.push_back(
      {"vec8",
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int grid = (M + block_size - 1) / block_size;
         gemv_bf16_TN_vec<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
                                                y);
       }});
  kernels.push_back(
      {"vec8_ntl",
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int grid = (M + block_size - 1) / block_size;
         gemv_bf16_TN_vec<NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});

  // One warp per output row.
  kernels.push_back(
      {"warp",
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x,
                                                 beta, y);
       }});
  kernels.push_back(
      {"vec8+warp",
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x,
                                                     beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp",
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp<NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});

  // Multiple rows per warp: grid covers ceil(M / ROWS_PER_WARP) row groups.
  kernels.push_back(
      {"vec8+warp_mr" + std::to_string(ROWS_PER_WARP),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
                     warps_per_block - 1) /
                    warps_per_block;
         gemv_bf16_TN_vec_warp_mr<ROWS_PER_WARP>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
                     warps_per_block - 1) /
                    warps_per_block;
         gemv_bf16_TN_vec_warp_mr<ROWS_PER_WARP, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});

  // K-loop unrolled variants.
  kernels.push_back(
      {"vec8+warp+unroll" + std::to_string(UNROLL),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_unroll<UNROLL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp+unroll" + std::to_string(UNROLL),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_unroll<UNROLL, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});

  // Shared-memory x-tiling variants (TILE_K elements of x staged per block).
  kernels.push_back(
      {"vec8+warp+shm" + std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_shm<TILE_K>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp+shm" + std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_shm<TILE_K, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8+warp+unroll" + std::to_string(UNROLL) + "+shm" +
           std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_unroll_shm<UNROLL, TILE_K>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp+unroll" + std::to_string(UNROLL) + "+shm" +
           std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_unroll_shm<UNROLL, TILE_K, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
           std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
                     warps_per_block - 1) /
                    warps_per_block;
         gemv_bf16_TN_vec_warp_mr_shm<ROWS_PER_WARP, TILE_K>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
           std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
                     warps_per_block - 1) /
                    warps_per_block;
         gemv_bf16_TN_vec_warp_mr_shm<ROWS_PER_WARP, TILE_K, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});

  // Print run configuration.
  printf("GEMV Benchmarks: y = alpha * A^T * x + beta * y\n");
  printf("Block size=%d\n", block_size);
  printf("alpha=%.2f, beta=%.2f\n", alpha, beta);
  printf("M=%d, N=1, K=%d, lda=%d\n", M, K, lda);
  // %zu for size_t (was %lu with an int-typed M * lda recomputation).
  printf("sizeof(A)=%zu MB\n", size_A / 1024 / 1024);
  printf("L2 cache=%d MB\n", get_l2_cache_size());
  printf("Warmups=%d, Loops=%d\n", warmups, loops);

  // Run all registered kernels.
  run_benchmark(warmups, loops, kernels, M, K, alpha, d_A, lda, d_x, beta,
                d_y, do_verify);

  // Cleanup.
  checkHipErrors(hipFree(d_A));
  checkHipErrors(hipFree(d_x));
  checkHipErrors(hipFree(d_y));
  return 0;
}