#pragma once #include "gemv_utils.h" // Warp Size 根据架构自动选择 #if defined(__HIP_PLATFORM_AMD__) #define WARP_SIZE 64 // Hygon/AMD #else #define WARP_SIZE 32 // Nvidia #endif #define VEC_WIDTH 8 #define OFFSET(i, j, lda) ((i) + (j) * (lda)) #define OFFSET_T(i, j, lda) ((i) * (lda) + (j)) /** * 平台相关的 Shared Memory / LDS */ #if defined(__HIP_PLATFORM_AMD__) // Hygon/AMD: 64KB LDS per CU constexpr int MAX_SHMEM_BYTES_PER_BLOCK = 65536; #else // Nvidia: 48KB constexpr int MAX_SHMEM_BYTES_PER_BLOCK = 49152; #endif /** * 根据需求的并发 block 数量计算 shmem 用量(即 TILE_K 指定的 BF16 元素个数) * * AlignElements 为对齐粒度,即元素个数,默认 128-bit 对齐。 * - 8: 对齐到 128-bit (可能有利于 load128b) * - 16: 对齐到 256-bit (某些 MFMA 指令需求) * * concurrent_blocks: 期望的并发 block 数(用于计算可用 shmem) * - Hygon/AMD: 表示每个 CU 上的并发 block 数 * - Nvidia: 设置为 1 即可(每个 block 独立使用 shmem) */ template constexpr int calculate_tile_k(int concurrent_blocks = 1) { // 安全检查 if (concurrent_blocks < 1) concurrent_blocks = 1; // 计算每个 block 可用的 shmem int bytes_per_block = MAX_SHMEM_BYTES_PER_BLOCK / concurrent_blocks; // 转为元素个数 int max_elements = bytes_per_block / sizeof(hip_bfloat16); // 对齐 return (max_elements / AlignElements) * AlignElements; } /// 辅助结构体:把 float4 (128位) 重新解释为 8 个 bf16 struct __align__(16) bf16_x8 { hip_bfloat16 vals[VEC_WIDTH]; }; #if !defined(__NVCC__) && !defined(__CUDACC__) /// 替代 float4,因为 non-temporal load 需要基本类型 typedef float __attribute__((ext_vector_type(4))) float4_native; #endif /// 128-bit non-temporal load 或者 cached load template __device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) { if constexpr (USE_NTL) { #if defined(__NVCC__) || defined(__CUDACC__) // Nvidia 平台:PTX 内联汇编实现 cache streaming (ld.global.cs) uint4 tmp; // 128-bit = 4 x 32-bit asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(tmp.x), "=r"(tmp.y), "=r"(tmp.z), "=r"(tmp.w) : "l"(src) : "memory"); return *reinterpret_cast(&tmp); #else // Hygon/AMD 平台:使用 Clang 内置 non-temporal load 函数 // 把地址转换为 float4_native 指针 const float4_native *ptr = reinterpret_cast(src); // 使用 Clang 内置 non-temporal load 函数,生成带有 slc/nt 修饰符的加载指令 float4_native tmp = __builtin_nontemporal_load(ptr); // 把加载到的 128 位数据重新解释为 bf16_x8 return *reinterpret_cast(&tmp); #endif } else { return *reinterpret_cast(src); } } /** y = alpha * A^T * x + beta * y * Naive 实现: * - JKI * - 每个线程算一个输出,即 I 循环的一次迭代 */ __global__ void gemv_bf16_TN_naive(int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int m = blockIdx.x * blockDim.x + threadIdx.x; // output if (m >= M) return; const hip_bfloat16 *row_ptr = A + m * lda; float sum = 0.0f; for (int k = 0; k < K; k++) { float val_a = static_cast(row_ptr[k]); float val_x = static_cast(x[k]); sum += val_a * val_x; } float y_original = static_cast(y[m]); y[m] = hip_bfloat16(alpha * sum + beta * y_original); return; } /** y = alpha * A^T * x + beta * y * 向量化实现: * - JKI * - 每个线程算一个输出,即 I 循环的一次迭代。 * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 */ template __global__ void gemv_bf16_TN_vec(int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int m = blockIdx.x * blockDim.x + threadIdx.x; // output if (m >= M) return; const hip_bfloat16 *row_ptr = A + m * lda; float sum = 0.0f; // 每次读 VEC_WIDTH 个数据 for (int k = 0; k < K; k += VEC_WIDTH) { bf16_x8 a_vec = load_128b(&row_ptr[k]); bf16_x8 x_vec = *reinterpret_cast(&x[k]); #pragma unroll for (int i = 0; i < VEC_WIDTH; ++i) { sum += static_cast(a_vec.vals[i]) * static_cast(x_vec.vals[i]); } } float y_original = static_cast(y[m]); y[m] = hip_bfloat16(alpha * sum + beta * y_original); return; } /** y = alpha * A^T * x + beta * y * Warp 归约: * - JKI * - 每个 warp 算一个输出,相当于用 warp size 作为 stride 沿着 K 方向 tiling。 * - Warp 内归约。 */ __global__ void gemv_bf16_TN_warp(int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id; if (m >= M) return; const hip_bfloat16 *row_ptr = A + m * lda; float sum = 0.0f; const int stride = WARP_SIZE; for (int k = lane_id; k < K; k += stride) { float val_a = static_cast(row_ptr[k]); float val_x = static_cast(x[k]); sum += val_a * val_x; } #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { sum += __shfl_down(sum, offset); } // Lane 0 负责写回 if (lane_id == 0) { float y_original = static_cast(y[m]); y[m] = hip_bfloat16(alpha * sum + beta * y_original); } return; } /** y = alpha * A^T * x + beta * y * Vec + warp: * - JKI * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个 warp 算一个输出,warp 内归约。 */ template __global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id; if (m >= M) return; const hip_bfloat16 *row_ptr = A + m * lda; float sum = 0.0f; const int stride = WARP_SIZE * VEC_WIDTH; for (int k = lane_id * VEC_WIDTH; k < K; k += stride) { bf16_x8 a_vec = load_128b(&row_ptr[k]); bf16_x8 x_vec = *reinterpret_cast(&x[k]); #pragma unroll for (int i = 0; i < VEC_WIDTH; ++i) { sum += static_cast(a_vec.vals[i]) * static_cast(x_vec.vals[i]); } } #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { sum += __shfl_down(sum, offset); } // Lane 0 负责写回 if (lane_id == 0) { float y_original = static_cast(y[m]); y[m] = hip_bfloat16(alpha * sum + beta * y_original); } return; } /** y = alpha * A^T * x + beta * y * 单线程 vec + warp 处理多行: * - JKI * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个 warp 处理 ROWS_PER_WARP 个输出行,warp 内归约(每行独立归约)。 * - 每个 lane 维护 ROWS_PER_WARP 个累加器。 */ template __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; // 每个 warp 处理 ROWS_PER_WARP 行 int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP + warp_id * ROWS_PER_WARP; // 预先计算每一行的指针和原始 y 值 const hip_bfloat16 *row_ptr[ROWS_PER_WARP]; float y_original[ROWS_PER_WARP]; // 每个 lane 维护 ROWS_PER_WARP 个累加器 float sum[ROWS_PER_WARP] = {0.0f}; #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { int m = m_base + r; // 越界时指向 A,确保地址有效,消除后续分支 row_ptr[r] = (m < M) ? (A + m * lda) : A; // 读取有效的原始 y 值 y_original[r] = (m < M) ? static_cast(y[m]) : 0.0f; } const int stride = WARP_SIZE * VEC_WIDTH; for (int k = lane_id * VEC_WIDTH; k < K; k += stride) { bf16_x8 x_vec = *reinterpret_cast(&x[k]); bf16_x8 a_vecs[ROWS_PER_WARP]; // 批量加载,无分支 #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { a_vecs[r] = load_128b(&row_ptr[r][k]); } // 批量计算 #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { #pragma unroll for (int i = 0; i < VEC_WIDTH; ++i) { sum[r] += static_cast(a_vecs[r].vals[i]) * static_cast(x_vec.vals[i]); } } } // Warp 内归约(每行独立归约) #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { sum[r] += __shfl_down(sum[r], offset); } // Lane 0 写回结果 if (lane_id == 0) { int m = m_base + r; if (m < M) { y[m] = hip_bfloat16(alpha * sum[r] + beta * y_original[r]); } } } return; } /** y = alpha * A^T * x + beta * y * 单线程 vec + warp + 主循环 unroll: * - JKI * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个 warp 算一个输出,warp 内归约。 * - 主循环 unrolling。 */ template __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id; if (m >= M) return; const hip_bfloat16 *row_ptr = A + m * lda; float sum = 0.0f; // 主循环临时变量 bf16_x8 a_frag[UNROLL]; bf16_x8 x_frag[UNROLL]; int k0 = lane_id * VEC_WIDTH; int k = 0; // 主循环 const int stride = WARP_SIZE * VEC_WIDTH * UNROLL; for (; k <= K - stride; k += stride) { #pragma unroll for (int u = 0; u < UNROLL; ++u) { int offset = k + k0 + u * (WARP_SIZE * VEC_WIDTH); a_frag[u] = load_128b(&row_ptr[offset]); x_frag[u] = *reinterpret_cast(&x[offset]); } #pragma unroll for (int u = 0; u < UNROLL; ++u) { #pragma unroll for (int i = 0; i < VEC_WIDTH; ++i) { sum += static_cast(a_frag[u].vals[i]) * static_cast(x_frag[u].vals[i]); } } } // Tail 循环 for (; k < K; k += WARP_SIZE * VEC_WIDTH) { int offset = k + k0; if (offset >= K) continue; bf16_x8 a_vec = load_128b(&row_ptr[offset]); bf16_x8 x_vec = *reinterpret_cast(&x[offset]); for (int i = 0; i < VEC_WIDTH; ++i) { sum += static_cast(a_vec.vals[i]) * static_cast(x_vec.vals[i]); } } #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { sum += __shfl_down(sum, offset); } // Lane 0 负责写回 if (lane_id == 0) { float y_original = static_cast(y[m]); y[m] = hip_bfloat16(alpha * sum + beta * y_original); } return; } /** y = alpha * A^T * x + beta * y * 单线程 vec + warp + shmem 缓存 x: * - JKI * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个 warp 算一个输出,warp 内归约。 * - shmem 缓存 x,分块加载。 */ template __global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id; // 缓存 x 的一个 tile __shared__ hip_bfloat16 x_tile[TILE_K]; // 预先计算每一行的指针和原始 y 值 const hip_bfloat16 *row_ptr = A + m * lda; // 不需要分支 float y_original = (m < M) ? static_cast(y[m]) : 0.0f; float sum = 0.0f; // 外层循环遍历 K 维度的所有 tile for (int kk = 0; kk < K; kk += TILE_K) { int tile_size = min(TILE_K, K - kk); // Step 1: 所有线程协作加载 x 的当前 tile 到 LDS // 每个线程加载 VEC_WIDTH 个元素 for (int i = threadIdx.x * VEC_WIDTH; i < tile_size; i += blockDim.x * VEC_WIDTH) { if (i + VEC_WIDTH <= tile_size) { // 完整的向量化加载 *reinterpret_cast(&x_tile[i]) = *reinterpret_cast(&x[kk + i]); } else { // Tail 循环逐个加载 for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) { x_tile[i + j] = x[kk + i + j]; } } } __syncthreads(); // Step 2: 计算当前 tile 的贡献(有效的 warp 才参与计算) if (m < M) { const int stride = WARP_SIZE * VEC_WIDTH; for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) { if (k + VEC_WIDTH <= tile_size) { // 完整的向量化计算 bf16_x8 a_vec = load_128b(&row_ptr[kk + k]); bf16_x8 x_vec = *reinterpret_cast(&x_tile[k]); #pragma unroll for (int i = 0; i < VEC_WIDTH; ++i) { sum += static_cast(a_vec.vals[i]) * static_cast(x_vec.vals[i]); } } else { // Tail 循环 for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) { float val_a = static_cast(row_ptr[kk + k + i]); float val_x = static_cast(x_tile[k + i]); sum += val_a * val_x; } } } } __syncthreads(); } if (m >= M) return; // Warp 内归约 #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { sum += __shfl_down(sum, offset); } // Lane 0 写回结果 if (lane_id == 0) { y[m] = hip_bfloat16(alpha * sum + beta * y_original); } return; } /** y = alpha * A^T * x + beta * y * 单线程 vec + warp + 主循环 unroll + shmem 缓存 x: * - JKI * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个 warp 算一个输出,warp 内归约。 * - 主循环 unrolling。 * - shmem 缓存 x,分块加载。 */ template __global__ void gemv_bf16_TN_vec_warp_unroll_shm( int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id; // 缓存 x 的一个 tile __shared__ hip_bfloat16 x_tile[TILE_K]; // 预先计算每一行的指针和原始 y 值 const hip_bfloat16 *row_ptr = A + m * lda; // 不需要分支 float y_original = (m < M) ? static_cast(y[m]) : 0.0f; float sum = 0.0f; // 外层循环遍历 K 维度的所有 tile for (int kk = 0; kk < K; kk += TILE_K) { int tile_size = min(TILE_K, K - kk); // Step 1: 所有线程协作加载 x 的当前 tile 到 LDS // 每个线程加载 VEC_WIDTH 个元素 for (int i = threadIdx.x * VEC_WIDTH; i < tile_size; i += blockDim.x * VEC_WIDTH) { if (i + VEC_WIDTH <= tile_size) { // 完整的向量化加载 *reinterpret_cast(&x_tile[i]) = *reinterpret_cast(&x[kk + i]); } else { // Tail 循环逐个加载 for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) { x_tile[i + j] = x[kk + i + j]; } } } __syncthreads(); // Step 2: 计算当前 tile 的贡献(有效的 warp 才参与计算) if (m < M) { const int warp_stride = WARP_SIZE * VEC_WIDTH; const int unroll_stride = warp_stride * UNROLL; int k = lane_id * VEC_WIDTH; // 主循环:Unroll for (; k <= tile_size - unroll_stride; k += unroll_stride) { bf16_x8 a_frag[UNROLL]; bf16_x8 x_frag[UNROLL]; #pragma unroll for (int u = 0; u < UNROLL; ++u) { int current_k = k + u * warp_stride; a_frag[u] = load_128b(&row_ptr[kk + current_k]); x_frag[u] = *reinterpret_cast(&x_tile[current_k]); } #pragma unroll for (int u = 0; u < UNROLL; ++u) { #pragma unroll for (int i = 0; i < VEC_WIDTH; ++i) { sum += static_cast(a_frag[u].vals[i]) * static_cast(x_frag[u].vals[i]); } } } // Tail 循环 for (; k < tile_size; k += warp_stride) { if (k + VEC_WIDTH <= tile_size) { // 完整的向量化计算 bf16_x8 a_vec = load_128b(&row_ptr[kk + k]); bf16_x8 x_vec = *reinterpret_cast(&x_tile[k]); #pragma unroll for (int i = 0; i < VEC_WIDTH; ++i) { sum += static_cast(a_vec.vals[i]) * static_cast(x_vec.vals[i]); } } else { // Tail 循环 for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) { float val_a = static_cast(row_ptr[kk + k + i]); float val_x = static_cast(x_tile[k + i]); sum += val_a * val_x; } } } } __syncthreads(); } if (m >= M) return; // Warp 内归约 #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { sum += __shfl_down(sum, offset); } // Lane 0 写回结果 if (lane_id == 0) { y[m] = hip_bfloat16(alpha * sum + beta * y_original); } return; } /** y = alpha * A^T * x + beta * y * 单线程 vec + warp 处理多行 + shmem 缓存 x: * - JKI * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个 warp 处理 ROWS_PER_WARP 个输出行,warp 内归约(每行独立归约)。 * - 每个 lane 维护 ROWS_PER_WARP 个累加器。 * - shmem 缓存 x,分块加载。 */ template __global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ x, const float beta, hip_bfloat16 *__restrict__ y) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; // 每个 warp 处理 ROWS_PER_WARP 行 int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP + warp_id * ROWS_PER_WARP; // 缓存 x 的一个 tile __shared__ hip_bfloat16 x_tile[TILE_K]; // 每个 lane 维护 ROWS_PER_WARP 个累加器 float sum[ROWS_PER_WARP] = {0.0f}; // 预先计算每一行的指针和原始 y 值 const hip_bfloat16 *row_ptr[ROWS_PER_WARP]; float y_original[ROWS_PER_WARP]; #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { int m = m_base + r; // 越界时指向 A,确保地址有效,消除后续分支 row_ptr[r] = (m < M) ? (A + m * lda) : A; // 读取有效的原始 y 值 y_original[r] = (m < M) ? static_cast(y[m]) : 0.0f; } // 外层循环遍历 K 维度的所有 tile for (int kk = 0; kk < K; kk += TILE_K) { int tile_size = min(TILE_K, K - kk); // Step 1: 所有线程协作加载 x 的当前 tile 到 LDS for (int i = threadIdx.x * VEC_WIDTH; i < tile_size; i += blockDim.x * VEC_WIDTH) { if (i + VEC_WIDTH <= tile_size) { // 完整的向量化加载 *reinterpret_cast(&x_tile[i]) = *reinterpret_cast(&x[kk + i]); } else { // Tail 循环逐个加载 for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) { x_tile[i + j] = x[kk + i + j]; } } } __syncthreads(); // Step 2: 计算当前 tile 的贡献 // 每个 lane 处理 ROWS_PER_WARP 行 const int stride = WARP_SIZE * VEC_WIDTH; for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) { if (k + VEC_WIDTH <= tile_size) { // 完整的向量化计算 bf16_x8 x_vec = *reinterpret_cast(&x_tile[k]); bf16_x8 a_vecs[ROWS_PER_WARP]; // 批量加载,无分支 #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { a_vecs[r] = load_128b(&row_ptr[r][kk + k]); } // 批量计算 #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { #pragma unroll for (int i = 0; i < VEC_WIDTH; ++i) { sum[r] += static_cast(a_vecs[r].vals[i]) * static_cast(x_vec.vals[i]); } } } else { // Tail 循环 for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) { float val_x = static_cast(x_tile[k + i]); #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { float val_a = static_cast(row_ptr[r][kk + k + i]); sum[r] += val_a * val_x; } } } } __syncthreads(); } // Warp 内归约(每行独立归约) #pragma unroll for (int r = 0; r < ROWS_PER_WARP; ++r) { #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { sum[r] += __shfl_down(sum[r], offset); } // Lane 0 写回结果 if (lane_id == 0) { int m = m_base + r; if (m < M) { y[m] = hip_bfloat16(alpha * sum[r] + beta * y_original[r]); } } } return; }