#pragma once #include #include #include #include #include #include #include #include // -------------------------------------------------------------------------------- // Error Handling // -------------------------------------------------------------------------------- inline void checkHipErrors(hipError_t result) { if (result != hipSuccess) { std::cerr << "HIP Error: " << hipGetErrorString(result) << std::endl; exit(1); } } // -------------------------------------------------------------------------------- // Command Line Parsing // -------------------------------------------------------------------------------- inline char *getCmdOption(char **begin, char **end, const std::string &option) { char **itr = std::find(begin, end, option); if (itr != end && ++itr != end) { return *itr; } return 0; } // -------------------------------------------------------------------------------- // CPU Reference & Verification // -------------------------------------------------------------------------------- inline void gemv_cpu(int M, int K, float alpha, const hip_bfloat16 *h_A, int lda, const hip_bfloat16 *h_x, float beta, hip_bfloat16 *h_y) { for (int m = 0; m < M; ++m) { float sum = 0.0f; for (int k = 0; k < K; ++k) { float val_a = static_cast(h_A[m * lda + k]); float val_x = static_cast(h_x[k]); sum += val_a * val_x; } h_y[m] = hip_bfloat16(alpha * sum + beta * h_y[m]); } return; } inline bool verify_result(int M, const hip_bfloat16 *h_y_gpu, const hip_bfloat16 *h_y_ref) { float max_diff = 0.0f; int err_count = 0; for (int i = 0; i < M; ++i) { float gpu_val = static_cast(h_y_gpu[i]); float cpu_val = static_cast(h_y_ref[i]); float diff = std::abs(gpu_val - cpu_val); // bfloat16 的精度有限,容忍度需要稍大一点 // 同时也考虑数值大小,这里使用简单的绝对误差 + 相对误差阈值 if (diff > 0.1f && diff / (std::abs(cpu_val) + 1e-6) > 0.01) { if (err_count < 5) { std::cerr << "Mismatch at index " << i << ": GPU=" << gpu_val << ", CPU=" << cpu_val << ", Diff=" << diff << std::endl; } err_count++; } max_diff = std::max(max_diff, diff); } if (err_count > 0) { std::cerr << "Total mismatches: " << err_count << ", Max Diff: " << max_diff << std::endl; return false; } return true; } // -------------------------------------------------------------------------------- // Benchmark Framework // -------------------------------------------------------------------------------- // 定义统一的 Kernel Launcher 签名 using KernelLauncher = std::function; struct KernelCase { std::string name; KernelLauncher func; }; inline void run_benchmark(const std::vector &cases, int M, int K, float alpha, const hip_bfloat16 *A, int lda, const hip_bfloat16 *x, float beta, hip_bfloat16 *y, bool do_verify) { std::cout << "GEMV Benchmarks" << std::endl; hipEvent_t start, stop; checkHipErrors(hipEventCreate(&start)); checkHipErrors(hipEventCreate(&stop)); // 准备 verification 数据 std::vector h_y_ref(M); std::vector h_y_gpu(M); std::vector h_A(M * K); std::vector h_x(K); if (do_verify) { std::cout << "Verifying results against CPU reference..." << std::endl; // 把 A 和 x 拷回 host 端计算 checkHipErrors(hipMemcpy(h_A.data(), A, M * K * sizeof(hip_bfloat16), hipMemcpyDeviceToHost)); checkHipErrors(hipMemcpy(h_x.data(), x, K * sizeof(hip_bfloat16), hipMemcpyDeviceToHost)); // 计算 CPU Reference gemv_cpu(M, K, alpha, h_A.data(), lda, h_x.data(), beta, h_y_ref.data()); } // 列宽 const int w_table = 80; // 表头 printf("%s\n", std::string(w_table, '-').c_str()); printf("M=%d, K=%d, N=1\n", M, K); printf("lda=%d\n", lda); printf("%s\n", std::string(w_table, '-').c_str()); printf("%-38s %10s %10s %10s %8s\n", "Kernel Name", "Time (us)", "GFLOPS", "BW (GB/s)", "Result"); for (const auto &k : cases) { std::string result_status = "Skipped"; // 1. Verification (如果启用) if (do_verify) { // 清零 d_y checkHipErrors(hipMemset(y, 0, M * sizeof(hip_bfloat16))); // 运行一次 k.func(M, K, alpha, A, lda, x, beta, y); checkHipErrors(hipDeviceSynchronize()); // 拷回结果 checkHipErrors(hipMemcpy(h_y_gpu.data(), y, M * sizeof(hip_bfloat16), hipMemcpyDeviceToHost)); if (verify_result(M, h_y_gpu.data(), h_y_ref.data())) { result_status = "PASS"; } else { result_status = "FAIL"; } } // 2. Warmup for (int i = 0; i < 100; ++i) { k.func(M, K, alpha, A, lda, x, beta, y); } checkHipErrors(hipDeviceSynchronize()); // 3. Timing int num_runs = 1000; checkHipErrors(hipEventRecord(start)); for (int i = 0; i < num_runs; ++i) { k.func(M, K, alpha, A, lda, x, beta, y); } checkHipErrors(hipEventRecord(stop)); checkHipErrors(hipEventSynchronize(stop)); float total_ms = 0; checkHipErrors(hipEventElapsedTime(&total_ms, start, stop)); float avg_ms = total_ms / num_runs; // 4. Metrics double gflops = (2.0 * M * K) / (avg_ms * 1e-3) / 1e9; // Bandwidth = Read A + Read x + Write y // A: M*K, x: K, y: M double bytes_moved = (double)(M * K + K + M) * sizeof(hip_bfloat16); double bw = bytes_moved / (avg_ms * 1e-3) / 1e9; printf("%-38s %10.1f %10.2f %10.2f %8s\n", k.name.c_str(), avg_ms * 1e3, gflops, bw, result_status.c_str()); } std::cout << std::string(w_table, '-') << std::endl; checkHipErrors(hipEventDestroy(start)); checkHipErrors(hipEventDestroy(stop)); return; }