/**
 * GEMV with a GEMM-like interface, i.e. N = 1.
 * Build:
 *   hipcc -std=c++17 -O3 --offload-arch=gfx936 gemv_bf16.hip -o gemv_bench
 * Run:
 *   HIP_VISIBLE_DEVICES=4 numactl -N 0 -m 0 ./gemv_bench -M 11264 -K 4096
 */
#include "gemv_utils.h"

#define WARP_SIZE 64
#define VEC_WIDTH 8

#define OFFSET(i, j, lda) ((i) + (j) * (lda))
#define OFFSET_T(i, j, lda) ((i) * (lda) + (j))

/**
 * Compute the shared-memory budget (TILE_K, counted in BF16 elements) that
 * allows the requested number of blocks to be resident on a CU at once.
 *
 * AlignElements is the alignment granularity in elements; the default aligns
 * to 128-bit.
 *  - 8:  align to 128-bit (may help 128-bit loads)
 *  - 16: align to 256-bit (required by some MFMA instructions)
 */
template <int AlignElements = 8>
constexpr int calculate_tile_k(int concurrent_blocks) {
  // Sanity clamp.
  if (concurrent_blocks < 1) concurrent_blocks = 1;
  // Split the LDS evenly between the concurrent blocks.
  constexpr int MAX_LDS_BYTES_PER_CU = 65536;
  int bytes_per_block = MAX_LDS_BYTES_PER_CU / concurrent_blocks;
  // Convert bytes to element count.
  int max_elements = bytes_per_block / sizeof(hip_bfloat16);
  // Round down to the alignment granularity.
  return (max_elements / AlignElements) * AlignElements;
}

/// Helper: reinterpret a 128-bit chunk as 8 bf16 values.
/// The 16-byte alignment is required so it can alias a float4-sized load.
struct __align__(16) bf16_x8 {
  hip_bfloat16 vals[VEC_WIDTH];
};

/// float4 replacement (Clang extended vector type).
typedef float __attribute__((ext_vector_type(4))) float4_native;

/// 128-bit load: non-temporal when USE_NTL, regular cached load otherwise.
/// Precondition: src must be 16-byte aligned.
template <bool USE_NTL = false>
__device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
  if constexpr (USE_NTL) {
    // View the address as a 128-bit vector pointer.
    const float4_native *ptr = reinterpret_cast<const float4_native *>(src);
    // Clang builtin non-temporal load: emits the load with the
    // non-temporal/slc cache modifiers set.
    float4_native tmp = __builtin_nontemporal_load(ptr);
    // Reinterpret the loaded 128 bits as 8 bf16 values.
    return *reinterpret_cast<bf16_x8 *>(&tmp);
  } else {
    return *reinterpret_cast<const bf16_x8 *>(src);
  }
}

/** y = alpha * A^T * x + 0 * y
 * Naive version:
 * - JKI order
 * - One output element per thread, i.e. one iteration of the I loop.
 * Launch: 1D grid, blockDim.x threads per block, grid covers ceil(M / block).
 */
__global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
                                   const hip_bfloat16 *__restrict__ A, int lda,
                                   const hip_bfloat16 *__restrict__ x,
                                   const float beta, // 0, unused by contract
                                   hip_bfloat16 *__restrict__ y) {
  int m = blockIdx.x * blockDim.x + threadIdx.x; // output row
  if (m >= M) return;

  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  for (int k = 0; k < K; k++) {
    float val_a = static_cast<float>(row_ptr[k]);
    float val_x = static_cast<float>(x[k]);
    sum += val_a * val_x;
  }
  y[m] = hip_bfloat16(alpha * sum);
}

/** y = alpha * A^T * x + 0 * y
 * Vectorized version:
 * - JKI order
 * - One output element per thread.
 * - Each thread reads VEC_WIDTH bf16 values per step (matrix A may use
 *   non-temporal loads).
 * Preconditions: K is a multiple of VEC_WIDTH; A rows and x are 16-byte
 * aligned.
 */
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
                                 const hip_bfloat16 *__restrict__ A, int lda,
                                 const hip_bfloat16 *__restrict__ x,
                                 const float beta, // 0, unused by contract
                                 hip_bfloat16 *__restrict__ y) {
  int m = blockIdx.x * blockDim.x + threadIdx.x; // output row
  if (m >= M) return;

  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Read VEC_WIDTH elements per iteration.
  for (int k = 0; k < K; k += VEC_WIDTH) {
    bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
    for (int i = 0; i < VEC_WIDTH; ++i) {
      sum += static_cast<float>(a_vec.vals[i]) *
             static_cast<float>(x_vec.vals[i]);
    }
  }
  y[m] = hip_bfloat16(alpha * sum);
}

/** y = alpha * A^T * x + 0 * y
 * Warp reduction:
 * - JKI order
 * - One output element per warp; the warp strides along K with stride
 *   WARP_SIZE, then reduces within the warp.
 */
__global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
                                  const hip_bfloat16 *__restrict__ A, int lda,
                                  const hip_bfloat16 *__restrict__ x,
                                  const float beta, // 0, unused by contract
                                  hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  if (m >= M) return;

  const int stride = WARP_SIZE;
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  for (int k = lane_id; k < K; k += stride) {
    float val_a = static_cast<float>(row_ptr[k]);
    float val_x = static_cast<float>(x[k]);
    sum += val_a * val_x;
  }

  // Butterfly-free tree reduction across the 64 lanes.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
}

/** y = alpha * A^T * x + 0 * y
 * Vec + warp:
 * - JKI order
 * - Each thread reads VEC_WIDTH bf16 values per step (matrix A may use
 *   non-temporal loads).
 * - One output element per warp, warp-level reduction.
 * Preconditions: K is a multiple of VEC_WIDTH; A rows and x are 16-byte
 * aligned.
 */
template <bool USE_NTL = false>
__global__ void
gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
                      const hip_bfloat16 *__restrict__ A, int lda,
                      const hip_bfloat16 *__restrict__ x,
                      const float beta, // 0, unused by contract
                      hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  if (m >= M) return;

  const int stride = WARP_SIZE * VEC_WIDTH;
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
    bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
    for (int i = 0; i < VEC_WIDTH; ++i) {
      sum += static_cast<float>(a_vec.vals[i]) *
             static_cast<float>(x_vec.vals[i]);
    }
  }

#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
}

/** y = alpha * A^T * x + 0 * y
 * Vec + warp, multiple rows per warp:
 * - JKI order
 * - Each thread reads VEC_WIDTH bf16 values per step (matrix A may use
 *   non-temporal loads).
 * - Each warp produces ROWS_PER_WARP output rows, reduced independently.
 * - Each lane keeps ROWS_PER_WARP accumulators.
 * Preconditions: K is a multiple of VEC_WIDTH; A rows and x are 16-byte
 * aligned.
 */
template <int ROWS_PER_WARP, bool USE_NTL = false>
__global__ void
gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
                         const hip_bfloat16 *__restrict__ A, int lda,
                         const hip_bfloat16 *__restrict__ x,
                         const float beta, // 0, unused by contract
                         hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  // Each warp handles ROWS_PER_WARP consecutive rows.
  int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
               warp_id * ROWS_PER_WARP;

  // One accumulator per row.
  float sum[ROWS_PER_WARP] = {0.0f};

  // Precompute the row pointers.
  const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
    int m = m_base + r;
    // Out-of-range rows alias row 0 of A so the address stays valid and the
    // inner loop needs no branch; their results are discarded at write-back.
    row_ptr[r] = (m < M) ? (A + m * lda) : A;
  }

  const int stride = WARP_SIZE * VEC_WIDTH;
  for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
    bf16_x8 a_vecs[ROWS_PER_WARP];
    // Batched loads, no branches.
#pragma unroll
    for (int r = 0; r < ROWS_PER_WARP; ++r) {
      a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][k]);
    }
    // Batched FMAs.
#pragma unroll
    for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
      for (int i = 0; i < VEC_WIDTH; ++i) {
        sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
                  static_cast<float>(x_vec.vals[i]);
      }
    }
  }

  // Warp-level reduction, one row at a time.
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      sum[r] += __shfl_down(sum[r], offset);
    }
    // Lane 0 writes the result back.
    if (lane_id == 0) {
      int m = m_base + r;
      if (m < M) {
        y[m] = hip_bfloat16(alpha * sum[r]);
      }
    }
  }
}

/** y = alpha * A^T * x + 0 * y
 * Vec + warp + unrolled main loop:
 * - JKI order
 * - Each thread reads VEC_WIDTH bf16 values per step (matrix A may use
 *   non-temporal loads).
 * - One output element per warp, warp-level reduction.
 * - Main loop unrolled by UNROLL.
 * Preconditions: K is a multiple of VEC_WIDTH; A rows and x are 16-byte
 * aligned.
 */
template <int UNROLL, bool USE_NTL = false>
__global__ void
gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
                             const hip_bfloat16 *__restrict__ A, int lda,
                             const hip_bfloat16 *__restrict__ x,
                             const float beta, // 0, unused by contract
                             hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  if (m >= M) return;

  const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;

  // Fragments for the unrolled iterations.
  bf16_x8 a_frag[UNROLL];
  bf16_x8 x_frag[UNROLL];

  int k0 = lane_id * VEC_WIDTH;
  int k = 0;
  // Main loop: UNROLL vectorized steps per iteration.
  for (; k <= K - stride; k += stride) {
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
      int offset = k + k0 + u * (WARP_SIZE * VEC_WIDTH);
      a_frag[u] = load_128b<USE_NTL>(&row_ptr[offset]);
      x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
    }
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
      for (int i = 0; i < VEC_WIDTH; ++i) {
        sum += static_cast<float>(a_frag[u].vals[i]) *
               static_cast<float>(x_frag[u].vals[i]);
      }
    }
  }
  // Tail loop: one vectorized step at a time; lanes past K skip.
  for (; k < K; k += WARP_SIZE * VEC_WIDTH) {
    int offset = k + k0;
    if (offset >= K) continue;
    bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[offset]);
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
    for (int i = 0; i < VEC_WIDTH; ++i) {
      sum += static_cast<float>(a_vec.vals[i]) *
             static_cast<float>(x_vec.vals[i]);
    }
  }

#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
}

/** y = alpha * A^T * x + 0 * y
 * Vec + warp + shared-memory-cached x:
 * - JKI order
 * - Each thread reads VEC_WIDTH bf16 values per step (matrix A may use
 *   non-temporal loads).
 * - One output element per warp, warp-level reduction.
 * - x is staged through LDS in tiles of TILE_K elements.
 * Shared memory: TILE_K * sizeof(hip_bfloat16) bytes per block.
 */
template <int TILE_K, bool USE_NTL = false>
__global__ void
gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
                          const hip_bfloat16 *__restrict__ A, int lda,
                          const hip_bfloat16 *__restrict__ x,
                          const float beta, // 0, unused by contract
                          hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;

  // One tile of x. 16-byte aligned because it is accessed through bf16_x8.
  __shared__ alignas(16) hip_bfloat16 x_tile[TILE_K];

  // Never dereferenced when m >= M, so no branch is needed here.
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;

  // Walk all tiles along the K dimension.
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);

    // Step 1: all threads cooperatively stage the current x tile into LDS,
    // VEC_WIDTH elements per thread per step.
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized copy.
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Scalar tail.
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();

    // Step 2: accumulate this tile's contribution (valid warps only).
    // The guard is warp-uniform, so no barrier sits inside divergent flow.
    if (m < M) {
      const int stride = WARP_SIZE * VEC_WIDTH;
      for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
        if (k + VEC_WIDTH <= tile_size) {
          // Full vectorized step.
          bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
          bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_vec.vals[i]) *
                   static_cast<float>(x_vec.vals[i]);
          }
        } else {
          // Scalar tail.
          for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
            float val_a = static_cast<float>(row_ptr[kk + k + i]);
            float val_x = static_cast<float>(x_tile[k + i]);
            sum += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }

  if (m >= M) return;

  // Warp-level reduction.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
}

/** y = alpha * A^T * x + 0 * y
 * Vec + warp + unrolled main loop + shared-memory-cached x:
 * - JKI order
 * - Each thread reads VEC_WIDTH bf16 values per step (matrix A may use
 *   non-temporal loads).
 * - One output element per warp, warp-level reduction.
 * - Main loop unrolled by UNROLL.
 * - x is staged through LDS in tiles of TILE_K elements.
 * Shared memory: TILE_K * sizeof(hip_bfloat16) bytes per block.
 */
template <int TILE_K, int UNROLL, bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec_warp_unroll_shm(
    int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A,
    int lda, const hip_bfloat16 *__restrict__ x,
    const float beta, // 0, unused by contract
    hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;

  // One tile of x. 16-byte aligned because it is accessed through bf16_x8.
  __shared__ alignas(16) hip_bfloat16 x_tile[TILE_K];

  // Never dereferenced when m >= M, so no branch is needed here.
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;

  // Walk all tiles along the K dimension.
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);

    // Step 1: all threads cooperatively stage the current x tile into LDS,
    // VEC_WIDTH elements per thread per step.
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized copy.
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Scalar tail.
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();

    // Step 2: accumulate this tile's contribution (valid warps only).
    if (m < M) {
      const int warp_stride = WARP_SIZE * VEC_WIDTH;
      const int unroll_stride = warp_stride * UNROLL;
      int k = lane_id * VEC_WIDTH;

      // Main loop: UNROLL vectorized steps per iteration.
      for (; k <= tile_size - unroll_stride; k += unroll_stride) {
        bf16_x8 a_frag[UNROLL];
        bf16_x8 x_frag[UNROLL];
#pragma unroll
        for (int u = 0; u < UNROLL; ++u) {
          int current_k = k + u * warp_stride;
          a_frag[u] = load_128b<USE_NTL>(&row_ptr[kk + current_k]);
          x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x_tile[current_k]);
        }
#pragma unroll
        for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_frag[u].vals[i]) *
                   static_cast<float>(x_frag[u].vals[i]);
          }
        }
      }
      // Tail: one vectorized step at a time.
      for (; k < tile_size; k += warp_stride) {
        if (k + VEC_WIDTH <= tile_size) {
          // Full vectorized step.
          bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
          bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_vec.vals[i]) *
                   static_cast<float>(x_vec.vals[i]);
          }
        } else {
          // Scalar tail.
          for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
            float val_a = static_cast<float>(row_ptr[kk + k + i]);
            float val_x = static_cast<float>(x_tile[k + i]);
            sum += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }

  if (m >= M) return;

  // Warp-level reduction.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
}

/** y = alpha * A^T * x + 0 * y
 * Vec + warp, multiple rows per warp + shared-memory-cached x:
 * - JKI order
 * - Each thread reads VEC_WIDTH bf16 values per step (matrix A may use
 *   non-temporal loads).
 * - Each warp produces ROWS_PER_WARP output rows, reduced independently.
 * - Each lane keeps ROWS_PER_WARP accumulators.
 * - x is staged through LDS in tiles of TILE_K elements.
 * Shared memory: TILE_K * sizeof(hip_bfloat16) bytes per block.
 */
template <int TILE_K, int ROWS_PER_WARP, bool USE_NTL = false>
__global__ void
gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
                             const hip_bfloat16 *__restrict__ A, int lda,
                             const hip_bfloat16 *__restrict__ x,
                             const float beta, // 0, unused by contract
                             hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  // Each warp handles ROWS_PER_WARP consecutive rows.
  int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
               warp_id * ROWS_PER_WARP;

  // One tile of x. 16-byte aligned because it is accessed through bf16_x8.
  __shared__ alignas(16) hip_bfloat16 x_tile[TILE_K];

  // One accumulator per row.
  float sum[ROWS_PER_WARP] = {0.0f};

  // Precompute the row pointers.
  const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
    int m = m_base + r;
    // Out-of-range rows alias row 0 of A so the address stays valid and the
    // inner loop needs no branch; their results are discarded at write-back.
    row_ptr[r] = (m < M) ? (A + m * lda) : A;
  }

  // Walk all tiles along the K dimension.
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);

    // Step 1: all threads cooperatively stage the current x tile into LDS.
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized copy.
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Scalar tail.
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();

    // Step 2: accumulate this tile's contribution; each lane works on
    // ROWS_PER_WARP rows.
    const int stride = WARP_SIZE * VEC_WIDTH;
    for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
      if (k + VEC_WIDTH <= tile_size) {
        // Full vectorized step.
        bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
        bf16_x8 a_vecs[ROWS_PER_WARP];
        // Batched loads, no branches.
#pragma unroll
        for (int r = 0; r < ROWS_PER_WARP; ++r) {
          a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][kk + k]);
        }
        // Batched FMAs.
#pragma unroll
        for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
                      static_cast<float>(x_vec.vals[i]);
          }
        }
      } else {
        // Scalar tail.
        for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
          float val_x = static_cast<float>(x_tile[k + i]);
#pragma unroll
          for (int r = 0; r < ROWS_PER_WARP; ++r) {
            float val_a = static_cast<float>(row_ptr[r][kk + k + i]);
            sum[r] += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }

  // Warp-level reduction, one row at a time.
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      sum[r] += __shfl_down(sum[r], offset);
    }
    // Lane 0 writes the result back.
    if (lane_id == 0) {
      int m = m_base + r;
      if (m < M) {
        y[m] = hip_bfloat16(alpha * sum[r]);
      }
    }
  }
}

/// GEMV microbenchmarks.
/// y = alpha * A^T * x + beta * y
/// M = output dimension (default 11264)
/// K = reduction dimension (default 4096)
/// N = 1
int main(int argc, char **argv) {
  bool do_verify = false;
  float alpha = 1.0f;
  float beta = 0.0f;
  int M = 11264;
  int K = 4096;
  // int N = 1; // Unused
  int lda = K;
  int block_size = 256;

  if (char *value = getCmdOption(argv, argv + argc, "--verify")) {
    do_verify = std::stoi(value) == 1;
  }
  if (char *value = getCmdOption(argv, argv + argc, "--alpha")) {
    alpha = std::stof(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "-M")) {
    M = std::stoi(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "-K")) {
    K = std::stoi(value);
    lda = K;
  }
  if (char *value = getCmdOption(argv, argv + argc, "--lda")) {
    lda = std::stoi(value);
  }
  if (char *value = getCmdOption(argv, argv + argc, "-B")) {
    block_size = std::stoi(value);
  }

  // transA == T, so A is stored row-major with leading dimension lda.
  size_t count_A = (size_t)M * lda;
  size_t size_A = count_A * sizeof(hip_bfloat16);
  size_t size_x = (size_t)K * sizeof(hip_bfloat16);
  size_t size_y = (size_t)M * sizeof(hip_bfloat16);

  // Host allocations.
  std::vector<hip_bfloat16> h_A(count_A);
  std::vector<hip_bfloat16> h_x(K);
  std::vector<hip_bfloat16> h_y(M);

  // Random init. size_t index for A: count_A can exceed INT_MAX.
  const float rand_max = static_cast<float>(RAND_MAX);
  for (size_t i = 0; i < count_A; i++)
    h_A[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
  for (int i = 0; i < K; i++)
    h_x[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
  for (int i = 0; i < M; i++) h_y[i] = hip_bfloat16(0.0f);

  // Device allocations.
  hip_bfloat16 *d_A, *d_x, *d_y;
  checkHipErrors(hipMalloc(&d_A, size_A));
  checkHipErrors(hipMalloc(&d_x, size_x));
  checkHipErrors(hipMalloc(&d_y, size_y));
  checkHipErrors(hipMemcpy(d_A, h_A.data(), size_A, hipMemcpyHostToDevice));
  checkHipErrors(hipMemcpy(d_x, h_x.data(), size_x, hipMemcpyHostToDevice));
  checkHipErrors(hipMemcpy(d_y, h_y.data(), size_y, hipMemcpyHostToDevice));

  // Kernel registry.
  // NOTE(review): the element type (a {name, launcher} record) is assumed to
  // be declared in gemv_utils.h -- confirm its actual name there.
  std::vector<GemvKernel> kernels;
  constexpr bool NTL = true;
  constexpr int UNROLL = 4;
  constexpr int TILE_K = calculate_tile_k<8>(4);
  constexpr int ROWS_PER_WARP = 2;

  kernels.push_back(
      {"naive", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
                    const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int grid = (M + block_size - 1) / block_size;
         gemv_bf16_TN_naive<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
                                                  y);
       }});
  kernels.push_back(
      {"vec8", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
                   const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int grid = (M + block_size - 1) / block_size;
         gemv_bf16_TN_vec<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
                                                y);
       }});
  kernels.push_back(
      {"vec8_ntl", [&](int M, int K, float alpha, const hip_bfloat16 *A,
                       int lda, const hip_bfloat16 *x, float beta,
                       hip_bfloat16 *y) {
         int grid = (M + block_size - 1) / block_size;
         gemv_bf16_TN_vec<NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"warp", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
                   const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
                                                 y);
       }});
  kernels.push_back(
      {"vec8+warp", [&](int M, int K, float alpha, const hip_bfloat16 *A,
                        int lda, const hip_bfloat16 *x, float beta,
                        hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x,
                                                     beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp", [&](int M, int K, float alpha, const hip_bfloat16 *A,
                            int lda, const hip_bfloat16 *x, float beta,
                            hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp<NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8+warp_mr" + std::to_string(ROWS_PER_WARP),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
                     warps_per_block - 1) /
                    warps_per_block;
         gemv_bf16_TN_vec_warp_mr<ROWS_PER_WARP>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
                     warps_per_block - 1) /
                    warps_per_block;
         gemv_bf16_TN_vec_warp_mr<ROWS_PER_WARP, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8+warp+unroll" + std::to_string(UNROLL),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_unroll<UNROLL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp+unroll" + std::to_string(UNROLL),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_unroll<UNROLL, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8+warp+shm" + std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_shm<TILE_K>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp+shm" + std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_shm<TILE_K, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8+warp+unroll" + std::to_string(UNROLL) + "+shm" +
           std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_unroll_shm<TILE_K, UNROLL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp+unroll" + std::to_string(UNROLL) + "+shm" +
           std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = (M + warps_per_block - 1) / warps_per_block;
         gemv_bf16_TN_vec_warp_unroll_shm<TILE_K, UNROLL, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
           std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
                     warps_per_block - 1) /
                    warps_per_block;
         gemv_bf16_TN_vec_warp_mr_shm<TILE_K, ROWS_PER_WARP>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});
  kernels.push_back(
      {"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
           std::to_string(TILE_K),
       [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
           const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
         int warps_per_block = block_size / WARP_SIZE;
         int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
                     warps_per_block - 1) /
                    warps_per_block;
         gemv_bf16_TN_vec_warp_mr_shm<TILE_K, ROWS_PER_WARP, NTL>
             <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
       }});

  // Run every registered kernel.
  run_benchmark(kernels, M, K, alpha, d_A, lda, d_x, beta, d_y, do_verify);

  // Cleanup.
  checkHipErrors(hipFree(d_A));
  checkHipErrors(hipFree(d_x));
  checkHipErrors(hipFree(d_y));
  return 0;
}