Commit b3a56179 authored by one's avatar one
Browse files

Update GEMV benchmarks

parent 977247a7
...@@ -3,7 +3,7 @@ CXXFLAGS ?= -std=c++17 -O3 ...@@ -3,7 +3,7 @@ CXXFLAGS ?= -std=c++17 -O3
OFFLOAD_ARCH ?= gfx936 OFFLOAD_ARCH ?= gfx936
TARGET := gemv_bench TARGET := gemv_bench
SRC := gemv_bf16.hip SRC := gemv_bf16.cpp
DEP := gemv_utils.h DEP := gemv_utils.h
.PHONY: all clean .PHONY: all clean
......
/**
* 模仿 GEMM 接口的 GEMV,即 N=1。
* 编译:
* hipcc -std=c++17 -O3 --offload-arch=gfx936 gemv_bf16.hip -o gemv_bench
* 执行:
* HIP_VISIBLE_DEVICES=4 numactl -N 0 -m 0 ./gemv_bench -M 11264 -K 4096
*/
#include "gemv_utils.h"
#define WARP_SIZE 64
#define VEC_WIDTH 8
#define OFFSET(i, j, lda) ((i) + (j) * (lda))
#define OFFSET_T(i, j, lda) ((i) * (lda) + (j))
/**
* 根据需求的并发 block 数量计算 shmem 用量(即 TILE_K 指定的 BF16 元素个数)
*
* AlignElements 为对齐粒度,即元素个数,默认 128-bit 对齐。
* - 8: 对齐到 128-bit (可能有利于 load128b)
* - 16: 对齐到 256-bit (某些 MFMA 指令需求)
*/
template <int AlignElements = 8>
constexpr int calculate_tile_k(int concurrent_blocks) {
// 安全检查
if (concurrent_blocks < 1)
concurrent_blocks = 1;
// 直接切分 LDS
constexpr int MAX_LDS_BYTES_PER_CU = 65536;
int bytes_per_block = MAX_LDS_BYTES_PER_CU / concurrent_blocks;
// 转为元素个数
int max_elements = bytes_per_block / sizeof(hip_bfloat16);
// 对齐
return (max_elements / AlignElements) * AlignElements;
}
/// 辅助结构体:把 float4 (128位) 重新解释为 8 个 bf16
struct __align__(16) bf16_x8 {
hip_bfloat16 vals[VEC_WIDTH];
};
/// 替代 float4
typedef float __attribute__((ext_vector_type(4))) float4_native;
/// 128-bit non-temporal load 或者 cached load
template <bool USE_NTL = false>
__device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
if constexpr (USE_NTL) {
// 把地址转换为 float4_native 指针
const float4_native *ptr = reinterpret_cast<const float4_native *>(src);
// 使用 Clang 内置 non-temporal load 函数,生成带有 slc/nt 修饰符的加载指令
float4_native tmp = __builtin_nontemporal_load(ptr);
// 把加载到的 128 位数据重新解释为 bf16_x8
return *reinterpret_cast<bf16_x8 *>(&tmp);
} else {
return *reinterpret_cast<const bf16_x8 *>(src);
}
}
/** y = alpha * A^T * x + 0 * y
* Naive 实现:
* - JKI
* - 每个线程算一个输出,即 I 循环的一次迭代
*/
__global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = 0; k < K; k++) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
* 向量化实现:
* - JKI
* - 每个线程算一个输出,即 I 循环的一次迭代。
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 每次读 VEC_WIDTH 个数据
for (int k = 0; k < K; k += VEC_WIDTH) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
* Warp 归约:
* - JKI
* - 每个 warp 算一个输出,相当于用 warp size 作为 stride 沿着 K 方向 tiling。
* - Warp 内归约。
*/
__global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id; k < K; k += stride) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 负责写回
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* Vec + warp:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE * VEC_WIDTH;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 负责写回
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp 处理多行:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 处理 ROWS_PER_WARP 个输出行,warp 内归约(每行独立归约)。
* - 每个 lane 维护 ROWS_PER_WARP 个累加器。
*/
template <bool USE_NTL = false, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
// 每个 warp 处理 ROWS_PER_WARP 行
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP;
// 每个 lane 维护 ROWS_PER_WARP 个累加器
float sum[ROWS_PER_WARP] = {0.0f};
// 预先计算每一行的指针
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r;
// 越界时指向 A,确保地址有效,消除后续分支
row_ptr[r] = (m < M) ? (A + m * lda) : A;
}
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
bf16_x8 a_vecs[ROWS_PER_WARP];
// 批量加载,无分支
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][k]);
}
// 批量计算
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
}
}
// Warp 内归约(每行独立归约)
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum[r] += __shfl_down(sum[r], offset);
}
// Lane 0 写回结果
if (lane_id == 0) {
int m = m_base + r;
if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]);
}
}
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp + 主循环 unroll:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。
* - 主循环 unrolling。
*/
template <bool USE_NTL = false, int UNROLL = 4>
__global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 主循环临时变量
bf16_x8 a_frag[UNROLL];
bf16_x8 x_frag[UNROLL];
int k0 = lane_id * VEC_WIDTH;
int k = 0;
// 主循环
for (; k <= K - stride; k += stride) {
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int offset = k + k0 + u * (WARP_SIZE * VEC_WIDTH);
a_frag[u] = load_128b<USE_NTL>(&row_ptr[offset]);
x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_frag[u].vals[i]) *
static_cast<float>(x_frag[u].vals[i]);
}
}
}
// Tail 循环
for (; k < K; k += WARP_SIZE * VEC_WIDTH) {
int offset = k + k0;
if (offset >= K)
continue;
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[offset]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 负责写回
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp + shmem 缓存 x:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。
* - shmem 缓存 x,分块加载。
*/
template <bool USE_NTL = false, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
// 缓存 x 的一个 tile
__shared__ hip_bfloat16 x_tile[TILE_K];
// 不会在 m>=M 时访问 A,因此不需要分支
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 外层循环遍历 K 维度的所有 tile
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: 所有线程协作加载 x 的当前 tile 到 LDS
// 每个线程加载 VEC_WIDTH 个元素
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// 完整的向量化加载
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail 循环逐个加载
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: 计算当前 tile 的贡献(有效的 warp 才参与计算)
if (m < M) {
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
if (k + VEC_WIDTH <= tile_size) {
// 完整的向量化计算
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_vec.vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
} else {
// Tail 循环
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_a = static_cast<float>(row_ptr[kk + k + i]);
float val_x = static_cast<float>(x_tile[k + i]);
sum += val_a * val_x;
}
}
}
}
__syncthreads();
}
if (m >= M)
return;
// Warp 内归约
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 写回结果
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp + 主循环 unroll + shmem 缓存 x:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。
* - 主循环 unrolling。
* - shmem 缓存 x,分块加载。
*/
template <bool USE_NTL = false, int UNROLL = 4, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_unroll_shm(
int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A,
int lda, const hip_bfloat16 *__restrict__ x, const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
// 缓存 x 的一个 tile
__shared__ hip_bfloat16 x_tile[TILE_K];
// 不会在 m>=M 时访问 A,因此不需要分支
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 外层循环遍历 K 维度的所有 tile
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: 所有线程协作加载 x 的当前 tile 到 LDS
// 每个线程加载 VEC_WIDTH 个元素
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// 完整的向量化加载
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail 循环逐个加载
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: 计算当前 tile 的贡献(有效的 warp 才参与计算)
if (m < M) {
const int warp_stride = WARP_SIZE * VEC_WIDTH;
const int unroll_stride = warp_stride * UNROLL;
int k = lane_id * VEC_WIDTH;
// 主循环:Unroll
for (; k <= tile_size - unroll_stride; k += unroll_stride) {
bf16_x8 a_frag[UNROLL];
bf16_x8 x_frag[UNROLL];
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int current_k = k + u * warp_stride;
a_frag[u] = load_128b<USE_NTL>(&row_ptr[kk + current_k]);
x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x_tile[current_k]);
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_frag[u].vals[i]) *
static_cast<float>(x_frag[u].vals[i]);
}
}
}
// Tail 循环
for (; k < tile_size; k += warp_stride) {
if (k + VEC_WIDTH <= tile_size) {
// 完整的向量化计算
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_vec.vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
} else {
// Tail 循环
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_a = static_cast<float>(row_ptr[kk + k + i]);
float val_x = static_cast<float>(x_tile[k + i]);
sum += val_a * val_x;
}
}
}
}
__syncthreads();
}
if (m >= M)
return;
// Warp 内归约
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 写回结果
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp 处理多行 + shmem 缓存 x:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 处理 ROWS_PER_WARP 个输出行,warp 内归约(每行独立归约)。
* - 每个 lane 维护 ROWS_PER_WARP 个累加器。
* - shmem 缓存 x,分块加载。
*/
template <bool USE_NTL = false, int TILE_K = 4096, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
// 每个 warp 处理 ROWS_PER_WARP 行
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP;
// 缓存 x 的一个 tile
__shared__ hip_bfloat16 x_tile[TILE_K];
// 每个 lane 维护 ROWS_PER_WARP 个累加器
float sum[ROWS_PER_WARP] = {0.0f};
// 预先计算每一行的指针
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r;
// 越界时指向 A,确保地址有效,消除后续分支
row_ptr[r] = (m < M) ? (A + m * lda) : A;
}
// 外层循环遍历 K 维度的所有 tile
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: 所有线程协作加载 x 的当前 tile 到 LDS
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// 完整的向量化加载
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail 循环逐个加载
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: 计算当前 tile 的贡献
// 每个 lane 处理 ROWS_PER_WARP 行
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
if (k + VEC_WIDTH <= tile_size) {
// 完整的向量化计算
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
bf16_x8 a_vecs[ROWS_PER_WARP];
// 批量加载,无分支
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][kk + k]);
}
// 批量计算
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
}
} else {
// Tail 循环
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_x = static_cast<float>(x_tile[k + i]);
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
float val_a = static_cast<float>(row_ptr[r][kk + k + i]);
sum[r] += val_a * val_x;
}
}
}
}
__syncthreads();
}
// Warp 内归约(每行独立归约)
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum[r] += __shfl_down(sum[r], offset);
}
// Lane 0 写回结果
if (lane_id == 0) {
int m = m_base + r;
if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]);
}
}
}
return;
}
/// GEMV Microbenchmarks
/// y = alpha * A^T * x + beta * y
/// M = 输出维度 (11264)
/// K = 归约维度 (4096)
/// N = 1
int main(int argc, char **argv) {
bool do_verify = false;
float alpha = 1.0f;
float beta = 0.0f;
int M = 11264;
int K = 4096;
// int N = 1; // Unused
int lda = K;
int block_size = 256;
if (char *value = getCmdOption(argv, argv + argc, "--verify")) {
do_verify = std::stoi(value) == 1;
}
if (char *value = getCmdOption(argv, argv + argc, "--alpha")) {
alpha = std::stof(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-M")) {
M = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-K")) {
K = std::stoi(value);
lda = K;
}
if (char *value = getCmdOption(argv, argv + argc, "--lda")) {
lda = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-B")) {
block_size = std::stoi(value);
}
// transA=T,因此是行优先
size_t count_A = (size_t)M * lda;
size_t size_A = count_A * sizeof(hip_bfloat16);
size_t size_x = (size_t)K * sizeof(hip_bfloat16);
size_t size_y = (size_t)M * sizeof(hip_bfloat16);
// Host 内存分配
std::vector<hip_bfloat16> h_A(count_A);
std::vector<hip_bfloat16> h_x(K);
std::vector<hip_bfloat16> h_y(M);
// 随机初始数据
const float rand_max = static_cast<float>(RAND_MAX);
for (int i = 0; i < count_A; i++)
h_A[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < K; i++)
h_x[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < M; i++)
h_y[i] = hip_bfloat16(0.0f);
// Device 内存分配
hip_bfloat16 *d_A, *d_x, *d_y;
checkHipErrors(hipMalloc(&d_A, size_A));
checkHipErrors(hipMalloc(&d_x, size_x));
checkHipErrors(hipMalloc(&d_y, size_y));
checkHipErrors(hipMemcpy(d_A, h_A.data(), size_A, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_x, h_x.data(), size_x, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_y, h_y.data(), size_y, hipMemcpyHostToDevice));
// Kernel 注册表
std::vector<KernelCase> kernels;
constexpr bool NTL = true;
constexpr int UNROLL = 4;
constexpr int TILE_K = calculate_tile_k<8>(4);
constexpr int ROWS_PER_WARP = 2;
kernels.push_back(
{"naive", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_naive<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8_ntl",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<NTL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"warp", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8+warp",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x,
beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<NTL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<!NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<!NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<!NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<!NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<!NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
// 运行所有测试
run_benchmark(kernels, M, K, alpha, d_A, lda, d_x, beta, d_y, do_verify);
// 清理
checkHipErrors(hipFree(d_A));
checkHipErrors(hipFree(d_x));
checkHipErrors(hipFree(d_y));
return 0;
}
\ No newline at end of file
/**
* 模仿 GEMM 接口的 GEMV,即 N=1。
* 编译:
* hipcc -std=c++17 -O3 --offload-arch=gfx936 gemv_bf16.hip -o gemv_bench
* 执行:
* HIP_VISIBLE_DEVICES=4 numactl -N 0 -m 0 ./gemv_bench -M 11264 -K 4096
*/
#include "gemv_utils.h"
#define WARP_SIZE 64
#define VEC_WIDTH 8
#define OFFSET(i, j, lda) ((i) + (j) * (lda))
#define OFFSET_T(i, j, lda) ((i) * (lda) + (j))
// 辅助结构体:把 float4 (128位) 重新解释为 8 个 bf16
struct __align__(16) bf16_x8 {
hip_bfloat16 vals[VEC_WIDTH];
};
// 替代 float4
typedef float __attribute__((ext_vector_type(4))) float4_native;
// 128-bit non-temporal load 或者 cached load
template <bool USE_NTL = false>
__device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
if constexpr (USE_NTL) {
// 把地址转换为 float4_native 指针
const float4_native *ptr = reinterpret_cast<const float4_native *>(src);
// 使用 Clang 内置 non-temporal load 函数,生成带有 slc/nt 修饰符的加载指令
float4_native tmp = __builtin_nontemporal_load(ptr);
// 把加载到的 128 位数据重新解释为 bf16_x8
return *reinterpret_cast<bf16_x8 *>(&tmp);
} else {
return *reinterpret_cast<const bf16_x8 *>(src);
}
}
/** y = alpha * A^T * x + 0 * y
* Naive 实现:
* - JKI
* - 每个线程算一个输出,即 I 循环的一次迭代
*/
__global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
// const float beta, // set to 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = 0; k < K; k++) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
* 向量化实现:
* - JKI
* - 每个线程算一个输出,即 I 循环的一次迭代。
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
// const float beta, // set to 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 每次读 VEC_WIDTH 个数据
for (int k = 0; k < K; k += VEC_WIDTH) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
* Warp 归约:
* - JKI
* - 每个 warp 算一个输出,相当于用 warp size 作为 stride 沿着 K 方向 tiling。
* - Warp 内归约。
*/
__global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
// const float beta, // set to 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id; k < K; k += stride) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 负责写回
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* Vec + warp:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
// const float beta, // set to 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE * VEC_WIDTH;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 负责写回
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp 处理多行:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 处理 ROWS_PER_WARP 个输出行,warp 内归约(每行独立归约)。
* - 每个 lane 维护 ROWS_PER_WARP 个累加器。
*/
template <bool USE_NTL = false, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
// 每个 warp 处理 ROWS_PER_WARP 行
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP;
// 每个 lane 维护 ROWS_PER_WARP 个累加器
float sum[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
sum[r] = 0.0f;
}
// 预先计算每一行的指针
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r;
// 越界时指向 A,确保地址有效,消除后续分支
row_ptr[r] = (m < M) ? (A + m * lda) : A;
}
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
bf16_x8 a_vecs[ROWS_PER_WARP];
// 批量加载,无分支
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][k]);
}
// 批量计算
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
}
}
// Warp 内归约(每行独立归约)
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum[r] += __shfl_down(sum[r], offset);
}
// Lane 0 写回结果
if (lane_id == 0) {
int m = m_base + r;
if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]);
}
}
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp + 主循环 unroll:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。
* - 主循环 unrolling。
*/
template <bool USE_NTL = false, int UNROLL = 4>
__global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
// const float beta, // set to 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 主循环临时变量
bf16_x8 a_frag[UNROLL];
bf16_x8 x_frag[UNROLL];
int k0 = lane_id * VEC_WIDTH;
int k = 0;
// 主循环
for (; k <= K - stride; k += stride) {
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int offset = k + k0 + u * (WARP_SIZE * VEC_WIDTH);
a_frag[u] = load_128b<USE_NTL>(&row_ptr[offset]);
x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_frag[u].vals[i]) *
static_cast<float>(x_frag[u].vals[i]);
}
}
}
// Tail 循环
for (; k < K; k += WARP_SIZE * VEC_WIDTH) {
int offset = k + k0;
if (offset >= K)
continue;
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[offset]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 负责写回
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp + shmem 缓存 x:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。
* - shmem 缓存 x,分块加载。
*/
template <bool USE_NTL = false, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
// 缓存 x 的一个 tile
__shared__ hip_bfloat16 x_tile[TILE_K];
// 不会在 m>=M 时访问 A,因此不需要分支
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 外层循环遍历 K 维度的所有 tile
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: 所有线程协作加载 x 的当前 tile 到 LDS
// 每个线程加载 VEC_WIDTH 个元素
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// 完整的向量化加载
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail 循环逐个加载
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: 计算当前 tile 的贡献(有效的 warp 才参与计算)
if (m < M) {
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
if (k + VEC_WIDTH <= tile_size) {
// 完整的向量化计算
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_vec.vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
} else {
// Tail 循环
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_a = static_cast<float>(row_ptr[kk + k + i]);
float val_x = static_cast<float>(x_tile[k + i]);
sum += val_a * val_x;
}
}
}
}
__syncthreads();
}
if (m >= M)
return;
// Warp 内归约
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 写回结果
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp + 主循环 unroll + shmem 缓存 x:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。
* - 主循环 unrolling。
* - shmem 缓存 x,分块加载。
*/
template <bool USE_NTL = false, int UNROLL = 4, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_unroll_shm(
int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A,
int lda, const hip_bfloat16 *__restrict__ x, hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
// 缓存 x 的一个 tile
__shared__ hip_bfloat16 x_tile[TILE_K];
// 不会在 m>=M 时访问 A,因此不需要分支
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 外层循环遍历 K 维度的所有 tile
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: 所有线程协作加载 x 的当前 tile 到 LDS
// 每个线程加载 VEC_WIDTH 个元素
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// 完整的向量化加载
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail 循环逐个加载
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: 计算当前 tile 的贡献(有效的 warp 才参与计算)
if (m < M) {
const int warp_stride = WARP_SIZE * VEC_WIDTH;
const int unroll_stride = warp_stride * UNROLL;
int k = lane_id * VEC_WIDTH;
// 主循环:Unroll
for (; k <= tile_size - unroll_stride; k += unroll_stride) {
bf16_x8 a_frag[UNROLL];
bf16_x8 x_frag[UNROLL];
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int current_k = k + u * warp_stride;
a_frag[u] = load_128b<USE_NTL>(&row_ptr[kk + current_k]);
x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x_tile[current_k]);
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_frag[u].vals[i]) *
static_cast<float>(x_frag[u].vals[i]);
}
}
}
// Tail 循环
for (; k < tile_size; k += warp_stride) {
if (k + VEC_WIDTH <= tile_size) {
// 完整的向量化计算
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_vec.vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
} else {
// Tail 循环
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_a = static_cast<float>(row_ptr[kk + k + i]);
float val_x = static_cast<float>(x_tile[k + i]);
sum += val_a * val_x;
}
}
}
}
__syncthreads();
}
if (m >= M)
return;
// Warp 内归约
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 写回结果
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* 单线程 vec + warp 处理多行 + shmem 缓存 x:
* - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 处理 ROWS_PER_WARP 个输出行,warp 内归约(每行独立归约)。
* - 每个 lane 维护 ROWS_PER_WARP 个累加器。
* - shmem 缓存 x,分块加载。
*/
template <bool USE_NTL = false, int TILE_K = 4096, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
// 每个 warp 处理 ROWS_PER_WARP 行
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP;
// 缓存 x 的一个 tile
__shared__ hip_bfloat16 x_tile[TILE_K];
// 每个 lane 维护 ROWS_PER_WARP 个累加器
float sum[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
sum[r] = 0.0f;
}
// 预先计算每一行的指针
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r;
// 越界时指向 A,确保地址有效,消除后续分支
row_ptr[r] = (m < M) ? (A + m * lda) : A;
}
// 外层循环遍历 K 维度的所有 tile
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: 所有线程协作加载 x 的当前 tile 到 LDS
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// 完整的向量化加载
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail 循环逐个加载
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: 计算当前 tile 的贡献
// 每个 lane 处理 ROWS_PER_WARP 行
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
if (k + VEC_WIDTH <= tile_size) {
// 完整的向量化计算
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
bf16_x8 a_vecs[ROWS_PER_WARP];
// 批量加载,无分支
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][kk + k]);
}
// 批量计算
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
}
} else {
// Tail 循环
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_x = static_cast<float>(x_tile[k + i]);
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
float val_a = static_cast<float>(row_ptr[r][kk + k + i]);
sum[r] += val_a * val_x;
}
}
}
}
__syncthreads();
}
// Warp 内归约(每行独立归约)
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum[r] += __shfl_down(sum[r], offset);
}
// Lane 0 写回结果
if (lane_id == 0) {
int m = m_base + r;
if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]);
}
}
}
return;
}
/// GEMV Microbenchmarks
/// y = alpha * A^T * x + beta * y
/// M = 输出维度 (11264)
/// K = 归约维度 (4096)
/// N = 1
int main(int argc, char **argv) {
bool do_verify = false;
float alpha = 1.0f;
int M = 11264;
int K = 4096;
// int N = 1; // Unused
int lda = K;
int block_size = 256;
if (char *value = getCmdOption(argv, argv + argc, "--verify")) {
do_verify = std::stoi(value) == 1;
}
if (char *value = getCmdOption(argv, argv + argc, "--alpha")) {
alpha = std::stof(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-M")) {
M = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-K")) {
K = std::stoi(value);
lda = K;
}
if (char *value = getCmdOption(argv, argv + argc, "--lda")) {
lda = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-B")) {
block_size = std::stoi(value);
}
// transA=T,因此是行优先
size_t count_A = (size_t)M * lda;
size_t size_A = count_A * sizeof(hip_bfloat16);
size_t size_x = (size_t)K * sizeof(hip_bfloat16);
size_t size_y = (size_t)M * sizeof(hip_bfloat16);
// Host 内存分配
std::vector<hip_bfloat16> h_A(count_A);
std::vector<hip_bfloat16> h_x(K);
std::vector<hip_bfloat16> h_y(M);
// 随机初始数据
const float rand_max = static_cast<float>(RAND_MAX);
for (int i = 0; i < count_A; i++)
h_A[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < K; i++)
h_x[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < M; i++)
h_y[i] = hip_bfloat16(0.0f);
// Device 内存分配
hip_bfloat16 *d_A, *d_x, *d_y;
checkHipErrors(hipMalloc(&d_A, size_A));
checkHipErrors(hipMalloc(&d_x, size_x));
checkHipErrors(hipMalloc(&d_y, size_y));
checkHipErrors(hipMemcpy(d_A, h_A.data(), size_A, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_x, h_x.data(), size_x, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_y, h_y.data(), size_y, hipMemcpyHostToDevice));
// Kernel 注册表
std::vector<KernelCase> kernels;
constexpr bool NTL = true;
constexpr int UNROLL = 4;
constexpr int TILE_K = 4096;
constexpr int ROWS_PER_WARP = 2;
kernels.push_back(
{"naive", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_naive<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"vec8", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"vec8_ntl", [&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<NTL><<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"warp", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"vec8+warp", [&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"vec8_ntl+warp", [&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<NTL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<!NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<!NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<!NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<!NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<!NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP) +
"+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
// 运行所有测试
run_benchmark(kernels, M, K, alpha, d_A, lda, d_x, d_y, do_verify);
// 清理
checkHipErrors(hipFree(d_A));
checkHipErrors(hipFree(d_x));
checkHipErrors(hipFree(d_y));
return 0;
}
\ No newline at end of file
...@@ -37,7 +37,8 @@ inline char *getCmdOption(char **begin, char **end, const std::string &option) { ...@@ -37,7 +37,8 @@ inline char *getCmdOption(char **begin, char **end, const std::string &option) {
// -------------------------------------------------------------------------------- // --------------------------------------------------------------------------------
inline void gemv_cpu(int M, int K, float alpha, const hip_bfloat16 *h_A, inline void gemv_cpu(int M, int K, float alpha, const hip_bfloat16 *h_A,
int lda, const hip_bfloat16 *h_x, hip_bfloat16 *h_y) { int lda, const hip_bfloat16 *h_x, float beta,
hip_bfloat16 *h_y) {
for (int m = 0; m < M; ++m) { for (int m = 0; m < M; ++m) {
float sum = 0.0f; float sum = 0.0f;
for (int k = 0; k < K; ++k) { for (int k = 0; k < K; ++k) {
...@@ -45,7 +46,7 @@ inline void gemv_cpu(int M, int K, float alpha, const hip_bfloat16 *h_A, ...@@ -45,7 +46,7 @@ inline void gemv_cpu(int M, int K, float alpha, const hip_bfloat16 *h_A,
float val_x = static_cast<float>(h_x[k]); float val_x = static_cast<float>(h_x[k]);
sum += val_a * val_x; sum += val_a * val_x;
} }
h_y[m] = hip_bfloat16(alpha * sum); h_y[m] = hip_bfloat16(alpha * sum + beta * h_y[m]);
} }
return; return;
...@@ -85,9 +86,9 @@ inline bool verify_result(int M, const hip_bfloat16 *h_y_gpu, ...@@ -85,9 +86,9 @@ inline bool verify_result(int M, const hip_bfloat16 *h_y_gpu,
// -------------------------------------------------------------------------------- // --------------------------------------------------------------------------------
// 定义统一的 Kernel Launcher 签名 // 定义统一的 Kernel Launcher 签名
using KernelLauncher = using KernelLauncher = std::function<void(
std::function<void(int M, int K, float alpha, const hip_bfloat16 *A, int M, int K, float alpha, const hip_bfloat16 *A, int lda,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y)>; const hip_bfloat16 *x, float beta, hip_bfloat16 *y)>;
struct KernelCase { struct KernelCase {
std::string name; std::string name;
...@@ -96,7 +97,7 @@ struct KernelCase { ...@@ -96,7 +97,7 @@ struct KernelCase {
inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K, inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
float alpha, const hip_bfloat16 *A, int lda, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, hip_bfloat16 *y, const hip_bfloat16 *x, float beta, hip_bfloat16 *y,
bool do_verify) { bool do_verify) {
std::cout << "GEMV Benchmarks" << std::endl; std::cout << "GEMV Benchmarks" << std::endl;
...@@ -120,7 +121,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K, ...@@ -120,7 +121,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
hipMemcpyDeviceToHost)); hipMemcpyDeviceToHost));
// 计算 CPU Reference // 计算 CPU Reference
gemv_cpu(M, K, alpha, h_A.data(), lda, h_x.data(), h_y_ref.data()); gemv_cpu(M, K, alpha, h_A.data(), lda, h_x.data(), beta, h_y_ref.data());
} }
// 列宽 // 列宽
...@@ -143,7 +144,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K, ...@@ -143,7 +144,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
checkHipErrors(hipMemset(y, 0, M * sizeof(hip_bfloat16))); checkHipErrors(hipMemset(y, 0, M * sizeof(hip_bfloat16)));
// 运行一次 // 运行一次
k.func(M, K, alpha, A, lda, x, y); k.func(M, K, alpha, A, lda, x, beta, y);
checkHipErrors(hipDeviceSynchronize()); checkHipErrors(hipDeviceSynchronize());
// 拷回结果 // 拷回结果
...@@ -159,7 +160,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K, ...@@ -159,7 +160,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
// 2. Warmup // 2. Warmup
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
k.func(M, K, alpha, A, lda, x, y); k.func(M, K, alpha, A, lda, x, beta, y);
} }
checkHipErrors(hipDeviceSynchronize()); checkHipErrors(hipDeviceSynchronize());
...@@ -167,7 +168,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K, ...@@ -167,7 +168,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
int num_runs = 1000; int num_runs = 1000;
checkHipErrors(hipEventRecord(start)); checkHipErrors(hipEventRecord(start));
for (int i = 0; i < num_runs; ++i) { for (int i = 0; i < num_runs; ++i) {
k.func(M, K, alpha, A, lda, x, y); k.func(M, K, alpha, A, lda, x, beta, y);
} }
checkHipErrors(hipEventRecord(stop)); checkHipErrors(hipEventRecord(stop));
checkHipErrors(hipEventSynchronize(stop)); checkHipErrors(hipEventSynchronize(stop));
...@@ -184,8 +185,8 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K, ...@@ -184,8 +185,8 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
double bytes_moved = (double)(M * K + K + M) * sizeof(hip_bfloat16); double bytes_moved = (double)(M * K + K + M) * sizeof(hip_bfloat16);
double bw = bytes_moved / (avg_ms * 1e-3) / 1e9; double bw = bytes_moved / (avg_ms * 1e-3) / 1e9;
printf("%-38s %10.1f %10.2f %10.2f %8s\n", k.name.c_str(), avg_ms * 1e3, gflops, printf("%-38s %10.1f %10.2f %10.2f %8s\n", k.name.c_str(), avg_ms * 1e3,
bw, result_status.c_str()); gflops, bw, result_status.c_str());
} }
std::cout << std::string(w_table, '-') << std::endl; std::cout << std::string(w_table, '-') << std::endl;
......
#!/bin/bash #!/bin/bash
# BW150
export HIP_VISIBLE_DEVICES=1
BIND_CMD="numactl -N 0 -m 0"
make make
# BW150 if [[ "$*" == *"--trace"* ]]; then
export HIP_VISIBLE_DEVICES=4 PROF_CMD="hipprof --trace-off --pmc"
${PROF_CMD} -o log/pmc-k1 ${BIND_CMD} ./gemv_bench --verify 1 -M 11264 -K 4096
hipprof numactl -N 0 -m 0 ./gemv_bench --verify 1 -M 11264 -K 4096 ${PROF_CMD} -o log/pmc-k2 ${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 11264
hipprof numactl -N 0 -m 0 ./gemv_bench --verify 1 -M 4096 -K 11264 ${PROF_CMD} -o log/pmc-k3 ${BIND_CMD} ./gemv_bench --verify 1 -M 12288 -K 4096
hipprof numactl -N 0 -m 0 ./gemv_bench --verify 1 -M 12288 -K 4096 ${PROF_CMD} -o log/pmc-k4 ${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 4096
hipprof numactl -N 0 -m 0 ./gemv_bench --verify 1 -M 4096 -K 4096 else
\ No newline at end of file ${BIND_CMD} ./gemv_bench --verify 1 -M 11264 -K 4096
${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 11264
${BIND_CMD} ./gemv_bench --verify 1 -M 12288 -K 4096
${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 4096
fi
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment