Commit 8ac49790 authored by one

Update GEMV benchmarks, move to a separate dir

parent b3a56179
HIPCC ?= hipcc
CXXFLAGS ?= -std=c++17 -O3
OFFLOAD_ARCH ?= gfx936
TARGET := gemv_bench
SRC := gemv_bf16.cpp
DEP := gemv_utils.h
.PHONY: all clean
all: $(TARGET)
$(TARGET): $(SRC) $(DEP)
$(HIPCC) $(CXXFLAGS) --offload-arch=$(OFFLOAD_ARCH) $< -o $@
clean:
rm -f $(TARGET)
/**
* A GEMV that mimics the GEMM interface, i.e. N=1.
* Build:
* hipcc -std=c++17 -O3 --offload-arch=gfx936 gemv_bf16.cpp -o gemv_bench
* Run:
* HIP_VISIBLE_DEVICES=4 numactl -N 0 -m 0 ./gemv_bench -M 11264 -K 4096
*/
#include "gemv_utils.h"
#define WARP_SIZE 64
#define VEC_WIDTH 8
#define OFFSET(i, j, lda) ((i) + (j) * (lda))
#define OFFSET_T(i, j, lda) ((i) * (lda) + (j))
/**
* Computes the shmem usage (the number of BF16 elements given by TILE_K)
* from the desired number of concurrent blocks.
*
* AlignElements is the alignment granularity, in elements; the default is
* 128-bit alignment.
* - 8: align to 128 bits (may help load_128b)
* - 16: align to 256 bits (required by some MFMA instructions)
*/
template <int AlignElements = 8>
constexpr int calculate_tile_k(int concurrent_blocks) {
// Sanity check
if (concurrent_blocks < 1)
concurrent_blocks = 1;
// Carve up the LDS directly
constexpr int MAX_LDS_BYTES_PER_CU = 65536;
int bytes_per_block = MAX_LDS_BYTES_PER_CU / concurrent_blocks;
// Convert to an element count
int max_elements = bytes_per_block / sizeof(hip_bfloat16);
// Align down
return (max_elements / AlignElements) * AlignElements;
}
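// Worked example (illustrative): calculate_tile_k<8>(4)
//   65536 B / 4 concurrent blocks = 16384 B per block
//   16384 B / sizeof(hip_bfloat16) = 8192 elements
//   already a multiple of 8, so TILE_K = 8192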
/// Helper struct: reinterprets a float4 (128 bits) as 8 bf16 values
struct __align__(16) bf16_x8 {
hip_bfloat16 vals[VEC_WIDTH];
};
/// Substitute for float4
typedef float __attribute__((ext_vector_type(4))) float4_native;
/// 128-bit non-temporal load or plain cached load
template <bool USE_NTL = false>
__device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
if constexpr (USE_NTL) {
// Cast the address to a float4_native pointer
const float4_native *ptr = reinterpret_cast<const float4_native *>(src);
// Use the Clang built-in non-temporal load, which emits loads with the slc/nt modifiers
float4_native tmp = __builtin_nontemporal_load(ptr);
// Reinterpret the loaded 128 bits as a bf16_x8
return *reinterpret_cast<bf16_x8 *>(&tmp);
} else {
return *reinterpret_cast<const bf16_x8 *>(src);
}
}
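// Note: both branches reinterpret src as a 128-bit value, so src must be
// 16-byte aligned. Callers step k in multiples of VEC_WIDTH (8 bf16 = 16 B),
// which keeps x and each row of A aligned, assuming the allocations are
// 16-byte aligned and lda is a multiple of VEC_WIDTH.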
/** y = alpha * A^T * x + 0 * y
 * Naive implementation:
 * - JKI loop order
 * - Each thread computes one output, i.e. one iteration of the I loop
*/
__global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = 0; k < K; k++) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized implementation:
 * - JKI loop order
 * - Each thread computes one output, i.e. one iteration of the I loop.
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// Read VEC_WIDTH elements at a time
for (int k = 0; k < K; k += VEC_WIDTH) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
 * Warp reduction:
 * - JKI loop order
 * - Each warp computes one output; effectively tiles K using the warp size as the stride.
 * - Intra-warp reduction.
*/
__global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id; k < K; k += stride) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
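// Shfl-based tree reduction: with WARP_SIZE = 64 the offsets run
// 32, 16, 8, 4, 2, 1, and after the last step lane 0 holds the sum of
// all 64 lanes' partial sums.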
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* Vec + warp:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp computes one output, with an intra-warp reduction.
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE * VEC_WIDTH;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + multiple rows per warp:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp handles ROWS_PER_WARP output rows; intra-warp reduction, one per row.
 * - Each lane keeps ROWS_PER_WARP accumulators.
*/
template <bool USE_NTL = false, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
// Each warp handles ROWS_PER_WARP rows
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP;
// Each lane keeps ROWS_PER_WARP accumulators
float sum[ROWS_PER_WARP] = {0.0f};
// Precompute a pointer to each row
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r;
// Point at A when out of range: keeps the address valid and removes later branches
row_ptr[r] = (m < M) ? (A + m * lda) : A;
}
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
bf16_x8 a_vecs[ROWS_PER_WARP];
// Batched loads, branch-free
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][k]);
}
// Batched compute
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
}
}
// Intra-warp reduction (each row reduced independently)
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum[r] += __shfl_down(sum[r], offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
int m = m_base + r;
if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]);
}
}
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + warp + unrolled main loop:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp computes one output, with an intra-warp reduction.
 * - Main-loop unrolling.
*/
template <bool USE_NTL = false, int UNROLL = 4>
__global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// Main-loop temporaries
bf16_x8 a_frag[UNROLL];
bf16_x8 x_frag[UNROLL];
int k0 = lane_id * VEC_WIDTH;
int k = 0;
// Main loop
for (; k <= K - stride; k += stride) {
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int offset = k + k0 + u * (WARP_SIZE * VEC_WIDTH);
a_frag[u] = load_128b<USE_NTL>(&row_ptr[offset]);
x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_frag[u].vals[i]) *
static_cast<float>(x_frag[u].vals[i]);
}
}
}
// Tail loop
for (; k < K; k += WARP_SIZE * VEC_WIDTH) {
int offset = k + k0;
if (offset >= K)
continue;
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[offset]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + warp + x staged in shmem:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp computes one output, with an intra-warp reduction.
 * - x is staged in shmem, one tile at a time.
*/
template <bool USE_NTL = false, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
// Cache one tile of x
__shared__ hip_bfloat16 x_tile[TILE_K];
// A is never dereferenced when m >= M, so no branch is needed here
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// Outer loop walks every tile along the K dimension
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: all threads cooperatively load the current tile of x into LDS
// Each thread loads VEC_WIDTH elements
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// Full vectorized load
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail: load element by element
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: accumulate this tile's contribution (only in-range warps compute)
if (m < M) {
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
if (k + VEC_WIDTH <= tile_size) {
// Full vectorized path
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_vec.vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
} else {
// Element-by-element tail
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_a = static_cast<float>(row_ptr[kk + k + i]);
float val_x = static_cast<float>(x_tile[k + i]);
sum += val_a * val_x;
}
}
}
}
__syncthreads();
}
if (m >= M)
return;
// Intra-warp reduction
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + warp + unrolled main loop + x staged in shmem:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp computes one output, with an intra-warp reduction.
 * - Main-loop unrolling.
 * - x is staged in shmem, one tile at a time.
*/
template <bool USE_NTL = false, int UNROLL = 4, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_unroll_shm(
int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A,
int lda, const hip_bfloat16 *__restrict__ x, const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
// Cache one tile of x
__shared__ hip_bfloat16 x_tile[TILE_K];
// A is never dereferenced when m >= M, so no branch is needed here
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// Outer loop walks every tile along the K dimension
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: all threads cooperatively load the current tile of x into LDS
// Each thread loads VEC_WIDTH elements
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// Full vectorized load
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail: load element by element
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: accumulate this tile's contribution (only in-range warps compute)
if (m < M) {
const int warp_stride = WARP_SIZE * VEC_WIDTH;
const int unroll_stride = warp_stride * UNROLL;
int k = lane_id * VEC_WIDTH;
// 主循环:Unroll
for (; k <= tile_size - unroll_stride; k += unroll_stride) {
bf16_x8 a_frag[UNROLL];
bf16_x8 x_frag[UNROLL];
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int current_k = k + u * warp_stride;
a_frag[u] = load_128b<USE_NTL>(&row_ptr[kk + current_k]);
x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x_tile[current_k]);
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_frag[u].vals[i]) *
static_cast<float>(x_frag[u].vals[i]);
}
}
}
// Tail loop
for (; k < tile_size; k += warp_stride) {
if (k + VEC_WIDTH <= tile_size) {
// Full vectorized path
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_vec.vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
} else {
// Element-by-element tail
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_a = static_cast<float>(row_ptr[kk + k + i]);
float val_x = static_cast<float>(x_tile[k + i]);
sum += val_a * val_x;
}
}
}
}
__syncthreads();
}
if (m >= M)
return;
// Intra-warp reduction
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + multiple rows per warp + x staged in shmem:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp handles ROWS_PER_WARP output rows; intra-warp reduction, one per row.
 * - Each lane keeps ROWS_PER_WARP accumulators.
 * - x is staged in shmem, one tile at a time.
*/
template <bool USE_NTL = false, int TILE_K = 4096, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
// Each warp handles ROWS_PER_WARP rows
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP;
// Cache one tile of x
__shared__ hip_bfloat16 x_tile[TILE_K];
// Each lane keeps ROWS_PER_WARP accumulators
float sum[ROWS_PER_WARP] = {0.0f};
// Precompute a pointer to each row
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r;
// Point at A when out of range: keeps the address valid and removes later branches
row_ptr[r] = (m < M) ? (A + m * lda) : A;
}
// Outer loop walks every tile along the K dimension
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: all threads cooperatively load the current tile of x into LDS
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// Full vectorized load
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail: load element by element
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: accumulate this tile's contribution
// Each lane handles ROWS_PER_WARP rows
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
if (k + VEC_WIDTH <= tile_size) {
// Full vectorized path
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
bf16_x8 a_vecs[ROWS_PER_WARP];
// Batched loads, branch-free
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][kk + k]);
}
// Batched compute
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
}
} else {
// Element-by-element tail
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_x = static_cast<float>(x_tile[k + i]);
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
float val_a = static_cast<float>(row_ptr[r][kk + k + i]);
sum[r] += val_a * val_x;
}
}
}
}
__syncthreads();
}
// Intra-warp reduction (each row reduced independently)
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum[r] += __shfl_down(sum[r], offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
int m = m_base + r;
if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]);
}
}
}
return;
}
/// GEMV Microbenchmarks
/// y = alpha * A^T * x + beta * y
/// M = output dimension (11264)
/// K = reduction dimension (4096)
/// N = 1
int main(int argc, char **argv) {
bool do_verify = false;
float alpha = 1.0f;
float beta = 0.0f;
int M = 11264;
int K = 4096;
// int N = 1; // Unused
int lda = K;
int block_size = 256;
if (char *value = getCmdOption(argv, argv + argc, "--verify")) {
do_verify = std::stoi(value) == 1;
}
if (char *value = getCmdOption(argv, argv + argc, "--alpha")) {
alpha = std::stof(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-M")) {
M = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-K")) {
K = std::stoi(value);
lda = K;
}
if (char *value = getCmdOption(argv, argv + argc, "--lda")) {
lda = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-B")) {
block_size = std::stoi(value);
}
// transA = T, so A is stored row-major
size_t count_A = (size_t)M * lda;
size_t size_A = count_A * sizeof(hip_bfloat16);
size_t size_x = (size_t)K * sizeof(hip_bfloat16);
size_t size_y = (size_t)M * sizeof(hip_bfloat16);
// Host allocations
std::vector<hip_bfloat16> h_A(count_A);
std::vector<hip_bfloat16> h_x(K);
std::vector<hip_bfloat16> h_y(M);
// Random initial data
const float rand_max = static_cast<float>(RAND_MAX);
for (int i = 0; i < count_A; i++)
h_A[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < K; i++)
h_x[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < M; i++)
h_y[i] = hip_bfloat16(0.0f);
// Device allocations
hip_bfloat16 *d_A, *d_x, *d_y;
checkHipErrors(hipMalloc(&d_A, size_A));
checkHipErrors(hipMalloc(&d_x, size_x));
checkHipErrors(hipMalloc(&d_y, size_y));
checkHipErrors(hipMemcpy(d_A, h_A.data(), size_A, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_x, h_x.data(), size_x, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_y, h_y.data(), size_y, hipMemcpyHostToDevice));
// Kernel registry
std::vector<KernelCase> kernels;
constexpr bool NTL = true;
constexpr int UNROLL = 4;
constexpr int TILE_K = calculate_tile_k<8>(4);
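// calculate_tile_k<8>(4) evaluates to 8192 bf16 elements (16 KiB of LDS per
// block), sized so that up to 4 blocks can be resident per CU.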
constexpr int ROWS_PER_WARP = 2;
kernels.push_back(
{"naive", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_naive<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8_ntl",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<NTL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"warp", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8+warp",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x,
beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<NTL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<!NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<!NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<!NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<!NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<!NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
// Run all benchmarks
run_benchmark(kernels, M, K, alpha, d_A, lda, d_x, beta, d_y, do_verify);
// Cleanup
checkHipErrors(hipFree(d_A));
checkHipErrors(hipFree(d_x));
checkHipErrors(hipFree(d_y));
return 0;
}
CXX ?= hipcc
CXX_FLAGS ?= -std=c++17 -O3
GPU_ARCH ?= gfx936
TARGET := gemv_bench
SRC := main.cpp
DEP := gemv_bf16.h gemv_utils.h hip_compat.h
.PHONY: all clean
all: $(TARGET)
# Pick the compile rule based on the CXX variable
ifneq (,$(findstring nvcc,$(CXX)))
# NVCC build
$(TARGET): $(SRC) $(DEP)
$(CXX) $(CXX_FLAGS) -arch=$(GPU_ARCH) -x cu $< -o $@
else
# HIPCC build
$(TARGET): $(SRC) $(DEP)
$(CXX) $(CXX_FLAGS) --offload-arch=$(GPU_ARCH) $< -o $@
endif
clean:
rm -f $(TARGET)
GEMV Benchmarks
---------------
A BF16 GEMV that mimics the GEMM interface, i.e. N=1. The matrix shapes come from Evo2 inference.

Computes: y = alpha * A^T * x + beta * y

- M: output dimension, e.g. 11264
- K: reduction dimension, e.g. 4096
- N: always 1
- beta: always 0
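
As a rough, back-of-the-envelope cost model (our estimate, not taken from the benchmark code): one GEMV performs 2\*M\*K FLOPs while reading roughly (M\*K + K) bf16 values. For the default M=11264, K=4096 that is about 92.3 MFLOP against ~88 MiB for A alone, so every kernel here is expected to be memory-bandwidth-bound.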
## Build
```bash
# With HIPCC:
CXX=hipcc make GPU_ARCH=gfx936
# With NVCC:
CXX=nvcc make GPU_ARCH=sm_80
```
## Run
```bash
# BW series:
HIP_VISIBLE_DEVICES=1 numactl -N 0 -m 0 ./gemv_bench -M 11264 -K 4096
# A800:
./gemv_bench -M 11264 -K 4096
```
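
## Verify

Each kernel can be checked against the CPU reference (this mirrors the invocation in run.sh):

```bash
./gemv_bench --verify 1 -M 11264 -K 4096
```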
#pragma once
#include "gemv_utils.h"
// Warp size is selected per architecture
#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64 // DCU
#else
#define WARP_SIZE 32 // NVIDIA
#endif
#define VEC_WIDTH 8
#define OFFSET(i, j, lda) ((i) + (j) * (lda))
#define OFFSET_T(i, j, lda) ((i) * (lda) + (j))
/**
* Computes the shmem usage (the number of BF16 elements given by TILE_K)
* from the desired number of concurrent blocks.
*
* AlignElements is the alignment granularity, in elements; the default is
* 128-bit alignment.
* - 8: align to 128 bits (may help load_128b)
* - 16: align to 256 bits (required by some MFMA instructions)
*/
template <int AlignElements = 8>
constexpr int calculate_tile_k(int concurrent_blocks) {
// Sanity check
if (concurrent_blocks < 1)
concurrent_blocks = 1;
// Carve up the LDS directly
constexpr int MAX_LDS_BYTES_PER_CU = 65536;
int bytes_per_block = MAX_LDS_BYTES_PER_CU / concurrent_blocks;
// Convert to an element count
int max_elements = bytes_per_block / sizeof(hip_bfloat16);
// Align down
return (max_elements / AlignElements) * AlignElements;
}
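// Worked example (illustrative): calculate_tile_k<8>(4)
//   65536 B / 4 concurrent blocks = 16384 B per block
//   16384 B / sizeof(hip_bfloat16) = 8192 elements
//   already a multiple of 8, so TILE_K = 8192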
/// Helper struct: reinterprets a float4 (128 bits) as 8 bf16 values
struct __align__(16) bf16_x8 {
hip_bfloat16 vals[VEC_WIDTH];
};
#if !defined(__NVCC__) && !defined(__CUDACC__)
/// Substitute for float4; the non-temporal load needs a fundamental type
typedef float __attribute__((ext_vector_type(4))) float4_native;
#endif
/// 128-bit non-temporal load or plain cached load
template <bool USE_NTL = false>
__device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
#if defined(__NVCC__) || defined(__CUDACC__)
// NVIDIA platform: use a plain load.
// NVCC's optimizer usually picks a suitable load instruction (e.g. LDG).
// For explicit control, use __ldg() or inline PTX.
return *reinterpret_cast<const bf16_x8 *>(src);
#else
if constexpr (USE_NTL) {
// DCU: use the Clang built-in non-temporal load
// Cast the address to a float4_native pointer
const float4_native *ptr = reinterpret_cast<const float4_native *>(src);
// The built-in emits loads with the slc/nt modifiers
float4_native tmp = __builtin_nontemporal_load(ptr);
// Reinterpret the loaded 128 bits as a bf16_x8
return *reinterpret_cast<bf16_x8 *>(&tmp);
} else {
return *reinterpret_cast<const bf16_x8 *>(src);
}
#endif
}
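// Note: both paths reinterpret src as a 128-bit value, so src must be
// 16-byte aligned. Callers step k in multiples of VEC_WIDTH (8 bf16 = 16 B),
// which keeps x and each row of A aligned, assuming the allocations are
// 16-byte aligned and lda is a multiple of VEC_WIDTH.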
/** y = alpha * A^T * x + 0 * y
 * Naive implementation:
 * - JKI loop order
 * - Each thread computes one output, i.e. one iteration of the I loop
*/
__global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = 0; k < K; k++) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized implementation:
 * - JKI loop order
 * - Each thread computes one output, i.e. one iteration of the I loop.
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// Read VEC_WIDTH elements at a time
for (int k = 0; k < K; k += VEC_WIDTH) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
 * Warp reduction:
 * - JKI loop order
 * - Each warp computes one output; effectively tiles K using the warp size as the stride.
 * - Intra-warp reduction.
*/
__global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id; k < K; k += stride) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
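// Shfl-based tree reduction: with WARP_SIZE = 64 the offsets run
// 32, 16, 8, 4, 2, 1 (16, 8, 4, 2, 1 when WARP_SIZE = 32); lane 0 ends up
// holding the sum of all lanes' partial sums.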
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
* Vec + warp:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp computes one output, with an intra-warp reduction.
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE * VEC_WIDTH;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + multiple rows per warp:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp handles ROWS_PER_WARP output rows; intra-warp reduction, one per row.
 * - Each lane keeps ROWS_PER_WARP accumulators.
*/
template <bool USE_NTL = false, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
// Each warp handles ROWS_PER_WARP rows
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP;
// Each lane keeps ROWS_PER_WARP accumulators
float sum[ROWS_PER_WARP] = {0.0f};
// Precompute a pointer to each row
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r;
// Point at A when out of range: keeps the address valid and removes later branches
row_ptr[r] = (m < M) ? (A + m * lda) : A;
}
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
bf16_x8 a_vecs[ROWS_PER_WARP];
// Batched loads, branch-free
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][k]);
}
// Batched compute
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
}
}
// Intra-warp reduction (each row reduced independently)
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum[r] += __shfl_down(sum[r], offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
int m = m_base + r;
if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]);
}
}
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + warp + unrolled main loop:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp computes one output, with an intra-warp reduction.
 * - Main-loop unrolling.
*/
template <bool USE_NTL = false, int UNROLL = 4>
__global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// Main-loop temporaries
bf16_x8 a_frag[UNROLL];
bf16_x8 x_frag[UNROLL];
int k0 = lane_id * VEC_WIDTH;
int k = 0;
// Main loop
for (; k <= K - stride; k += stride) {
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int offset = k + k0 + u * (WARP_SIZE * VEC_WIDTH);
a_frag[u] = load_128b<USE_NTL>(&row_ptr[offset]);
x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_frag[u].vals[i]) *
static_cast<float>(x_frag[u].vals[i]);
}
}
}
// Tail loop
for (; k < K; k += WARP_SIZE * VEC_WIDTH) {
int offset = k + k0;
if (offset >= K)
continue;
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[offset]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + warp + x staged in shmem:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp computes one output, with an intra-warp reduction.
 * - x is staged in shmem, one tile at a time.
*/
template <bool USE_NTL = false, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
// Cache one tile of x
__shared__ hip_bfloat16 x_tile[TILE_K];
// A is never dereferenced when m >= M, so no branch is needed here
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// Outer loop walks every tile along the K dimension
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: all threads cooperatively load the current tile of x into LDS
// Each thread loads VEC_WIDTH elements
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// Full vectorized load
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail: load element by element
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: accumulate this tile's contribution (only in-range warps compute)
if (m < M) {
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
if (k + VEC_WIDTH <= tile_size) {
// Full vectorized path
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_vec.vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
} else {
// Element-by-element tail
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_a = static_cast<float>(row_ptr[kk + k + i]);
float val_x = static_cast<float>(x_tile[k + i]);
sum += val_a * val_x;
}
}
}
}
__syncthreads();
}
if (m >= M)
return;
// Intra-warp reduction
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + warp + unrolled main loop + x staged in shmem:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp computes one output, with an intra-warp reduction.
 * - Main-loop unrolling.
 * - x is staged in shmem, one tile at a time.
*/
template <bool USE_NTL = false, int UNROLL = 4, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_unroll_shm(
int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A,
int lda, const hip_bfloat16 *__restrict__ x, const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
// Cache one tile of x
__shared__ hip_bfloat16 x_tile[TILE_K];
// A is never dereferenced when m >= M, so no branch is needed here
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// Outer loop walks every tile along the K dimension
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: all threads cooperatively load the current tile of x into LDS
// Each thread loads VEC_WIDTH elements
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// Full vectorized load
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail: load element by element
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: accumulate this tile's contribution (only in-range warps compute)
if (m < M) {
const int warp_stride = WARP_SIZE * VEC_WIDTH;
const int unroll_stride = warp_stride * UNROLL;
int k = lane_id * VEC_WIDTH;
// Main loop, unrolled
for (; k <= tile_size - unroll_stride; k += unroll_stride) {
bf16_x8 a_frag[UNROLL];
bf16_x8 x_frag[UNROLL];
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int current_k = k + u * warp_stride;
a_frag[u] = load_128b<USE_NTL>(&row_ptr[kk + current_k]);
x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x_tile[current_k]);
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_frag[u].vals[i]) *
static_cast<float>(x_frag[u].vals[i]);
}
}
}
// Tail loop
for (; k < tile_size; k += warp_stride) {
if (k + VEC_WIDTH <= tile_size) {
// Full vectorized path
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum += static_cast<float>(a_vec.vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
} else {
// Element-by-element tail
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_a = static_cast<float>(row_ptr[kk + k + i]);
float val_x = static_cast<float>(x_tile[k + i]);
sum += val_a * val_x;
}
}
}
}
__syncthreads();
}
if (m >= M)
return;
// Intra-warp reduction
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Per-thread vec + multiple rows per warp + x staged in shmem:
 * - JKI loop order
 * - Each thread reads VEC_WIDTH bf16 values at a time (non-temporal loads can be used for matrix A).
 * - Each warp handles ROWS_PER_WARP output rows; intra-warp reduction, one per row.
 * - Each lane keeps ROWS_PER_WARP accumulators.
 * - x is staged in shmem, one tile at a time.
*/
template <bool USE_NTL = false, int TILE_K = 4096, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A,
int lda,
const hip_bfloat16 *__restrict__ x,
const float beta, // 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
// Each warp handles ROWS_PER_WARP rows
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP;
// Cache one tile of x
__shared__ hip_bfloat16 x_tile[TILE_K];
// Each lane keeps ROWS_PER_WARP accumulators
float sum[ROWS_PER_WARP] = {0.0f};
// Precompute a pointer to each row
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r;
// Point at A when out of range: keeps the address valid and removes later branches
row_ptr[r] = (m < M) ? (A + m * lda) : A;
}
// Outer loop walks every tile along the K dimension
for (int kk = 0; kk < K; kk += TILE_K) {
int tile_size = min(TILE_K, K - kk);
// Step 1: all threads cooperatively load the current tile of x into LDS
for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
i += blockDim.x * VEC_WIDTH) {
if (i + VEC_WIDTH <= tile_size) {
// Full vectorized load
*reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
*reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
} else {
// Tail: load element by element
for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
x_tile[i + j] = x[kk + i + j];
}
}
}
__syncthreads();
// Step 2: accumulate this tile's contribution
// Each lane handles ROWS_PER_WARP rows
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
if (k + VEC_WIDTH <= tile_size) {
// Full vectorized path
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
bf16_x8 a_vecs[ROWS_PER_WARP];
// Batched loads, branch-free
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][kk + k]);
}
// Batched compute
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
static_cast<float>(x_vec.vals[i]);
}
}
} else {
// Element-by-element tail
for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
float val_x = static_cast<float>(x_tile[k + i]);
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
float val_a = static_cast<float>(row_ptr[r][kk + k + i]);
sum[r] += val_a * val_x;
}
}
}
}
__syncthreads();
}
// Intra-warp reduction (each row reduced independently)
#pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum[r] += __shfl_down(sum[r], offset);
}
// Lane 0 writes the result back
if (lane_id == 0) {
int m = m_base + r;
if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]);
}
}
}
return;
}
#pragma once
#include "hip_compat.h"
#include <algorithm>
#include <functional>
#include <hip/hip_bfloat16.h>
#include <hip/hip_runtime.h>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
// ============================================================================
// Error Handling
// ============================================================================
inline void checkHipErrors(hipError_t result) {
if (result != hipSuccess) {
@@ -20,9 +19,21 @@ inline void checkHipErrors(hipError_t result) {
}
}
// ============================================================================
// Device Info
// ============================================================================
/// L2 cache size in MB
inline int get_l2_cache_size(int device = 0) {
hipDeviceProp_t prop;
checkHipErrors(hipGetDeviceProperties(&prop, device));
return prop.l2CacheSize / 1024 / 1024;
}
// ============================================================================
// Command Line Parsing
// ============================================================================
inline char *getCmdOption(char **begin, char **end, const std::string &option) {
char **itr = std::find(begin, end, option);
@@ -32,9 +43,9 @@ inline char *getCmdOption(char **begin, char **end, const std::string &option) {
return 0;
}
// ============================================================================
// CPU Reference & Verification
// ============================================================================
inline void gemv_cpu(int M, int K, float alpha, const hip_bfloat16 *h_A,
int lda, const hip_bfloat16 *h_x, float beta,
@@ -46,7 +57,7 @@ inline void gemv_cpu(int M, int K, float alpha, const hip_bfloat16 *h_A,
float val_x = static_cast<float>(h_x[k]);
sum += val_a * val_x;
}
h_y[m] = hip_bfloat16(alpha * sum + beta * static_cast<float>(h_y[m]));
}
return;
@@ -81,9 +92,9 @@ inline bool verify_result(int M, const hip_bfloat16 *h_y_gpu,
return true;
}
// ============================================================================
// Benchmark Framework
// ============================================================================
// Uniform signature shared by all kernel launchers
using KernelLauncher = std::function<void(
@@ -131,6 +142,8 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
printf("%s\n", std::string(w_table, '-').c_str());
printf("M=%d, K=%d, N=1\n", M, K);
printf("lda=%d\n", lda);
printf("sizeof(A)=%lu MB\n", M * lda * sizeof(hip_bfloat16) / 1024 / 1024);
printf("L2 cache=%d MB\n", get_l2_cache_size());
printf("%s\n", std::string(w_table, '-').c_str());
printf("%-38s %10s %10s %10s %8s\n", "Kernel Name", "Time (us)", "GFLOPS",
"BW (GB/s)", "Result");
@@ -165,7 +178,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
checkHipErrors(hipDeviceSynchronize());
// 3. Timing
int num_runs = 100;
checkHipErrors(hipEventRecord(start));
for (int i = 0; i < num_runs; ++i) {
k.func(M, K, alpha, A, lda, x, beta, y);
......
#pragma once
/**
 * HIP-to-CUDA compatibility layer.
 *
 * When compiling with nvcc, HIP APIs are mapped to the CUDA API.
 * When compiling with hipcc, the native HIP headers are used.
*/
#if defined(__NVCC__) || defined(__CUDACC__)
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <stdio.h>
// Runtime API mapping
#define hipMalloc cudaMalloc
#define hipFree cudaFree
#define hipMemcpy cudaMemcpy
#define hipMemcpyHostToDevice cudaMemcpyHostToDevice
#define hipMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define hipMemset cudaMemset
#define hipDeviceSynchronize cudaDeviceSynchronize
#define hipGetDeviceProperties cudaGetDeviceProperties
#define hipGetErrorString cudaGetErrorString
// Event API mapping
#define hipEvent_t cudaEvent_t
#define hipEventCreate cudaEventCreate
#define hipEventDestroy cudaEventDestroy
#define hipEventRecord cudaEventRecord
#define hipEventSynchronize cudaEventSynchronize
#define hipEventElapsedTime cudaEventElapsedTime
// Data type mapping
#define hipDeviceProp_t cudaDeviceProp
#define hipError_t cudaError_t
#define hipSuccess cudaSuccess
// CUDA uses __nv_bfloat16 where HIP uses hip_bfloat16
typedef __nv_bfloat16 hip_bfloat16;
// Shuffle intrinsic mapping
// CUDA 9.0+ requires the _sync variants, which take a warp mask;
// 0xffffffff means every thread in the warp participates
#ifndef __shfl_down
#define __shfl_down(val, offset) __shfl_down_sync(0xffffffff, val, offset)
#endif
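// Illustrative expansion under this mapping: on CUDA,
//   sum += __shfl_down(sum, offset);
// compiles as
//   sum += __shfl_down_sync(0xffffffff, sum, offset);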
#else
#include <hip/hip_runtime.h>
#include <hip/hip_bfloat16.h>
#endif
#include "gemv_bf16.h"
int main(int argc, char **argv) {
bool do_verify = false;
float alpha = 1.0f;
float beta = 0.0f;
int M = 11264;
int K = 4096;
// int N = 1; // Unused
int lda = K;
int block_size = 256;
if (char *value = getCmdOption(argv, argv + argc, "--verify")) {
do_verify = std::stoi(value) == 1;
}
if (char *value = getCmdOption(argv, argv + argc, "--alpha")) {
alpha = std::stof(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-M")) {
M = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-K")) {
K = std::stoi(value);
lda = K;
}
if (char *value = getCmdOption(argv, argv + argc, "--lda")) {
lda = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-B")) {
block_size = std::stoi(value);
}
// transA = T, so A is stored row-major
size_t count_A = (size_t)M * lda;
size_t size_A = count_A * sizeof(hip_bfloat16);
size_t size_x = (size_t)K * sizeof(hip_bfloat16);
size_t size_y = (size_t)M * sizeof(hip_bfloat16);
// Host allocations
std::vector<hip_bfloat16> h_A(count_A);
std::vector<hip_bfloat16> h_x(K);
std::vector<hip_bfloat16> h_y(M);
// Random initial data
const float rand_max = static_cast<float>(RAND_MAX);
for (int i = 0; i < count_A; i++)
h_A[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < K; i++)
h_x[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < M; i++)
h_y[i] = hip_bfloat16(0.0f);
// Device allocations
hip_bfloat16 *d_A, *d_x, *d_y;
checkHipErrors(hipMalloc(&d_A, size_A));
checkHipErrors(hipMalloc(&d_x, size_x));
checkHipErrors(hipMalloc(&d_y, size_y));
checkHipErrors(hipMemcpy(d_A, h_A.data(), size_A, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_x, h_x.data(), size_x, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_y, h_y.data(), size_y, hipMemcpyHostToDevice));
// Kernel registry
std::vector<KernelCase> kernels;
constexpr bool NTL = true;
constexpr int UNROLL = 4;
constexpr int TILE_K = calculate_tile_k<8>(4);
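// calculate_tile_k<8>(4) evaluates to 8192 bf16 elements (16 KiB of LDS per
// block), sized so that up to 4 blocks can be resident per CU.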
constexpr int ROWS_PER_WARP = 2;
kernels.push_back(
{"naive", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_naive<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8_ntl",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<NTL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"warp", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta,
y);
}});
kernels.push_back(
{"vec8+warp",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x,
beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp",
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<NTL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<!NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<!NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<!NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<!NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<!NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
kernels.push_back(
{"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid =
((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP + warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}});
// Run all benchmarks
run_benchmark(kernels, M, K, alpha, d_A, lda, d_x, beta, d_y, do_verify);
// Cleanup
checkHipErrors(hipFree(d_A));
checkHipErrors(hipFree(d_x));
checkHipErrors(hipFree(d_y));
return 0;
}
#!/bin/bash
set -e
# BW150
export HIP_VISIBLE_DEVICES=1
BIND_CMD="numactl -N 0 -m 0"
CXX=hipcc make
if [[ "$*" == *"--pmc"* ]]; then
PROF_CMD="hipprof --trace-off --pmc --pmc-type 3"
${PROF_CMD} -o log/pmc-k1 ${BIND_CMD} ./gemv_bench --verify 1 -M 11264 -K 4096
${PROF_CMD} -o log/pmc-k2 ${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 11264
${PROF_CMD} -o log/pmc-k3 ${BIND_CMD} ./gemv_bench --verify 1 -M 12288 -K 4096
${PROF_CMD} -o log/pmc-k4 ${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 4096
elif [[ "$*" == *"--trace"* ]]; then
PROF_CMD="hipprof --hip-trace"
${PROF_CMD} -o log/trace-k1 ${BIND_CMD} ./gemv_bench --verify 1 -M 11264 -K 4096
${PROF_CMD} -o log/trace-k2 ${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 11264
${PROF_CMD} -o log/trace-k3 ${BIND_CMD} ./gemv_bench --verify 1 -M 12288 -K 4096
${PROF_CMD} -o log/trace-k4 ${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 4096
else
${BIND_CMD} ./gemv_bench --verify 1 -M 11264 -K 4096
${BIND_CMD} ./gemv_bench --verify 1 -M 4096 -K 11264
......