"vscode:/vscode.git/clone" did not exist on "32ed692ec6abf07ed7e3ddab96def567882a5354"
Commit cc91f72b authored by one's avatar one
Browse files

Update GEMV kernels

parent a781cad3
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
* 平台相关的 Shared Memory / LDS * 平台相关的 Shared Memory / LDS
*/ */
#if defined(__HIP_PLATFORM_AMD__) #if defined(__HIP_PLATFORM_AMD__)
// Hygon/AMD: 64KB LDS per CU // Hygon/AMD: 64KB LDS per CU
constexpr int MAX_SHMEM_BYTES_PER_BLOCK = 65536; constexpr int MAX_SHMEM_BYTES_PER_BLOCK = 65536;
#else #else
// Nvidia: 48KB // Nvidia: 48KB
constexpr int MAX_SHMEM_BYTES_PER_BLOCK = 49152; constexpr int MAX_SHMEM_BYTES_PER_BLOCK = 49152;
#endif #endif
/** /**
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
* AlignElements 为对齐粒度,即元素个数,默认 128-bit 对齐。 * AlignElements 为对齐粒度,即元素个数,默认 128-bit 对齐。
* - 8: 对齐到 128-bit (可能有利于 load128b) * - 8: 对齐到 128-bit (可能有利于 load128b)
* - 16: 对齐到 256-bit (某些 MFMA 指令需求) * - 16: 对齐到 256-bit (某些 MFMA 指令需求)
* *
* concurrent_blocks: 期望的并发 block 数(用于计算可用 shmem) * concurrent_blocks: 期望的并发 block 数(用于计算可用 shmem)
* - Hygon/AMD: 表示每个 CU 上的并发 block 数 * - Hygon/AMD: 表示每个 CU 上的并发 block 数
* - Nvidia: 设置为 1 即可(每个 block 独立使用 shmem) * - Nvidia: 设置为 1 即可(每个 block 独立使用 shmem)
...@@ -91,7 +91,7 @@ __device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) { ...@@ -91,7 +91,7 @@ __device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
} }
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* Naive 实现: * Naive 实现:
* - JKI * - JKI
* - 每个线程算一个输出,即 I 循环的一次迭代 * - 每个线程算一个输出,即 I 循环的一次迭代
...@@ -99,7 +99,7 @@ __device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) { ...@@ -99,7 +99,7 @@ __device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
__global__ void gemv_bf16_TN_naive(int M, int K, const float alpha, __global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x, const hip_bfloat16 *__restrict__ x,
const float beta, // 0 const float beta,
hip_bfloat16 *__restrict__ y) { hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M) if (m >= M)
...@@ -113,12 +113,14 @@ __global__ void gemv_bf16_TN_naive(int M, int K, const float alpha, ...@@ -113,12 +113,14 @@ __global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
float val_x = static_cast<float>(x[k]); float val_x = static_cast<float>(x[k]);
sum += val_a * val_x; sum += val_a * val_x;
} }
y[m] = hip_bfloat16(alpha * sum);
float y_original = static_cast<float>(y[m]);
y[m] = hip_bfloat16(alpha * sum + beta * y_original);
return; return;
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* 向量化实现: * 向量化实现:
* - JKI * - JKI
* - 每个线程算一个输出,即 I 循环的一次迭代。 * - 每个线程算一个输出,即 I 循环的一次迭代。
...@@ -128,7 +130,7 @@ template <bool USE_NTL = false> ...@@ -128,7 +130,7 @@ template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec(int M, int K, const float alpha, __global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x, const hip_bfloat16 *__restrict__ x,
const float beta, // 0 const float beta,
hip_bfloat16 *__restrict__ y) { hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M) if (m >= M)
...@@ -149,12 +151,13 @@ __global__ void gemv_bf16_TN_vec(int M, int K, const float alpha, ...@@ -149,12 +151,13 @@ __global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
} }
} }
y[m] = hip_bfloat16(alpha * sum); float y_original = static_cast<float>(y[m]);
y[m] = hip_bfloat16(alpha * sum + beta * y_original);
return; return;
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* Warp 归约: * Warp 归约:
* - JKI * - JKI
* - 每个 warp 算一个输出,相当于用 warp size 作为 stride 沿着 K 方向 tiling。 * - 每个 warp 算一个输出,相当于用 warp size 作为 stride 沿着 K 方向 tiling。
...@@ -163,7 +166,7 @@ __global__ void gemv_bf16_TN_vec(int M, int K, const float alpha, ...@@ -163,7 +166,7 @@ __global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
__global__ void gemv_bf16_TN_warp(int M, int K, const float alpha, __global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda, const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x, const hip_bfloat16 *__restrict__ x,
const float beta, // 0 const float beta,
hip_bfloat16 *__restrict__ y) { hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
...@@ -172,9 +175,9 @@ __global__ void gemv_bf16_TN_warp(int M, int K, const float alpha, ...@@ -172,9 +175,9 @@ __global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
if (m >= M) if (m >= M)
return; return;
const int stride = WARP_SIZE;
const hip_bfloat16 *row_ptr = A + m * lda; const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f; float sum = 0.0f;
const int stride = WARP_SIZE;
for (int k = lane_id; k < K; k += stride) { for (int k = lane_id; k < K; k += stride) {
float val_a = static_cast<float>(row_ptr[k]); float val_a = static_cast<float>(row_ptr[k]);
...@@ -189,25 +192,25 @@ __global__ void gemv_bf16_TN_warp(int M, int K, const float alpha, ...@@ -189,25 +192,25 @@ __global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
// Lane 0 负责写回 // Lane 0 负责写回
if (lane_id == 0) { if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum); float y_original = static_cast<float>(y[m]);
y[m] = hip_bfloat16(alpha * sum + beta * y_original);
} }
return; return;
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* Vec + warp: * Vec + warp:
* - JKI * - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
* - 每个 warp 算一个输出,warp 内归约。 * - 每个 warp 算一个输出,warp 内归约。
*/ */
template <bool USE_NTL = false> template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha, __global__ void
const hip_bfloat16 *__restrict__ A, gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
int lda, const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x, const hip_bfloat16 *__restrict__ x, const float beta,
const float beta, // 0 hip_bfloat16 *__restrict__ y) {
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id; int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
...@@ -215,9 +218,9 @@ __global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha, ...@@ -215,9 +218,9 @@ __global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
if (m >= M) if (m >= M)
return; return;
const int stride = WARP_SIZE * VEC_WIDTH;
const hip_bfloat16 *row_ptr = A + m * lda; const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f; float sum = 0.0f;
const int stride = WARP_SIZE * VEC_WIDTH;
for (int k = lane_id * VEC_WIDTH; k < K; k += stride) { for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]); bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
...@@ -237,13 +240,14 @@ __global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha, ...@@ -237,13 +240,14 @@ __global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
// Lane 0 负责写回 // Lane 0 负责写回
if (lane_id == 0) { if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum); float y_original = static_cast<float>(y[m]);
y[m] = hip_bfloat16(alpha * sum + beta * y_original);
} }
return; return;
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* 单线程 vec + warp 处理多行: * 单线程 vec + warp 处理多行:
* - JKI * - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
...@@ -251,12 +255,11 @@ __global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha, ...@@ -251,12 +255,11 @@ __global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
* - 每个 lane 维护 ROWS_PER_WARP 个累加器。 * - 每个 lane 维护 ROWS_PER_WARP 个累加器。
*/ */
template <bool USE_NTL = false, int ROWS_PER_WARP = 2> template <bool USE_NTL = false, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha, __global__ void
const hip_bfloat16 *__restrict__ A, gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
int lda, const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x, const hip_bfloat16 *__restrict__ x, const float beta,
const float beta, // 0 hip_bfloat16 *__restrict__ y) {
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
...@@ -264,16 +267,20 @@ __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha, ...@@ -264,16 +267,20 @@ __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP + int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
warp_id * ROWS_PER_WARP; warp_id * ROWS_PER_WARP;
// 预先计算每一行的指针和原始 y 值
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
float y_original[ROWS_PER_WARP];
// 每个 lane 维护 ROWS_PER_WARP 个累加器 // 每个 lane 维护 ROWS_PER_WARP 个累加器
float sum[ROWS_PER_WARP] = {0.0f}; float sum[ROWS_PER_WARP] = {0.0f};
// 预先计算每一行的指针
const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll #pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) { for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r; int m = m_base + r;
// 越界时指向 A,确保地址有效,消除后续分支 // 越界时指向 A,确保地址有效,消除后续分支
row_ptr[r] = (m < M) ? (A + m * lda) : A; row_ptr[r] = (m < M) ? (A + m * lda) : A;
// 读取有效的原始 y 值
y_original[r] = (m < M) ? static_cast<float>(y[m]) : 0.0f;
} }
const int stride = WARP_SIZE * VEC_WIDTH; const int stride = WARP_SIZE * VEC_WIDTH;
...@@ -311,7 +318,7 @@ __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha, ...@@ -311,7 +318,7 @@ __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
if (lane_id == 0) { if (lane_id == 0) {
int m = m_base + r; int m = m_base + r;
if (m < M) { if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]); y[m] = hip_bfloat16(alpha * sum[r] + beta * y_original[r]);
} }
} }
} }
...@@ -319,7 +326,7 @@ __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha, ...@@ -319,7 +326,7 @@ __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
return; return;
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* 单线程 vec + warp + 主循环 unroll: * 单线程 vec + warp + 主循环 unroll:
* - JKI * - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
...@@ -327,12 +334,11 @@ __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha, ...@@ -327,12 +334,11 @@ __global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
* - 主循环 unrolling。 * - 主循环 unrolling。
*/ */
template <bool USE_NTL = false, int UNROLL = 4> template <bool USE_NTL = false, int UNROLL = 4>
__global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha, __global__ void
const hip_bfloat16 *__restrict__ A, gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
int lda, const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x, const hip_bfloat16 *__restrict__ x,
const float beta, // 0 const float beta, hip_bfloat16 *__restrict__ y) {
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id; int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
...@@ -340,7 +346,6 @@ __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha, ...@@ -340,7 +346,6 @@ __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
if (m >= M) if (m >= M)
return; return;
const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
const hip_bfloat16 *row_ptr = A + m * lda; const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f; float sum = 0.0f;
...@@ -352,6 +357,7 @@ __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha, ...@@ -352,6 +357,7 @@ __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
int k = 0; int k = 0;
// 主循环 // 主循环
const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
for (; k <= K - stride; k += stride) { for (; k <= K - stride; k += stride) {
#pragma unroll #pragma unroll
for (int u = 0; u < UNROLL; ++u) { for (int u = 0; u < UNROLL; ++u) {
...@@ -392,13 +398,14 @@ __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha, ...@@ -392,13 +398,14 @@ __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
// Lane 0 负责写回 // Lane 0 负责写回
if (lane_id == 0) { if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum); float y_original = static_cast<float>(y[m]);
y[m] = hip_bfloat16(alpha * sum + beta * y_original);
} }
return; return;
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* 单线程 vec + warp + shmem 缓存 x: * 单线程 vec + warp + shmem 缓存 x:
* - JKI * - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
...@@ -406,12 +413,11 @@ __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha, ...@@ -406,12 +413,11 @@ __global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
* - shmem 缓存 x,分块加载。 * - shmem 缓存 x,分块加载。
*/ */
template <bool USE_NTL = false, int TILE_K = 4096> template <bool USE_NTL = false, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha, __global__ void
const hip_bfloat16 *__restrict__ A, gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
int lda, const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x, const hip_bfloat16 *__restrict__ x, const float beta,
const float beta, // 0 hip_bfloat16 *__restrict__ y) {
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id; int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
...@@ -419,8 +425,9 @@ __global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha, ...@@ -419,8 +425,9 @@ __global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
// 缓存 x 的一个 tile // 缓存 x 的一个 tile
__shared__ hip_bfloat16 x_tile[TILE_K]; __shared__ hip_bfloat16 x_tile[TILE_K];
// 不会在 m>=M 时访问 A,因此不需要分支 // 预先计算每一行的指针和原始 y 值
const hip_bfloat16 *row_ptr = A + m * lda; const hip_bfloat16 *row_ptr = A + m * lda; // 不需要分支
float y_original = (m < M) ? static_cast<float>(y[m]) : 0.0f;
float sum = 0.0f; float sum = 0.0f;
// 外层循环遍历 K 维度的所有 tile // 外层循环遍历 K 维度的所有 tile
...@@ -485,13 +492,13 @@ __global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha, ...@@ -485,13 +492,13 @@ __global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
// Lane 0 写回结果 // Lane 0 写回结果
if (lane_id == 0) { if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum); y[m] = hip_bfloat16(alpha * sum + beta * y_original);
} }
return; return;
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* 单线程 vec + warp + 主循环 unroll + shmem 缓存 x: * 单线程 vec + warp + 主循环 unroll + shmem 缓存 x:
* - JKI * - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
...@@ -502,7 +509,7 @@ __global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha, ...@@ -502,7 +509,7 @@ __global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
template <bool USE_NTL = false, int UNROLL = 4, int TILE_K = 4096> template <bool USE_NTL = false, int UNROLL = 4, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_unroll_shm( __global__ void gemv_bf16_TN_vec_warp_unroll_shm(
int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A, int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A,
int lda, const hip_bfloat16 *__restrict__ x, const float beta, // 0 int lda, const hip_bfloat16 *__restrict__ x, const float beta,
hip_bfloat16 *__restrict__ y) { hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
...@@ -511,8 +518,9 @@ __global__ void gemv_bf16_TN_vec_warp_unroll_shm( ...@@ -511,8 +518,9 @@ __global__ void gemv_bf16_TN_vec_warp_unroll_shm(
// 缓存 x 的一个 tile // 缓存 x 的一个 tile
__shared__ hip_bfloat16 x_tile[TILE_K]; __shared__ hip_bfloat16 x_tile[TILE_K];
// 不会在 m>=M 时访问 A,因此不需要分支 // 预先计算每一行的指针和原始 y 值
const hip_bfloat16 *row_ptr = A + m * lda; const hip_bfloat16 *row_ptr = A + m * lda; // 不需要分支
float y_original = (m < M) ? static_cast<float>(y[m]) : 0.0f;
float sum = 0.0f; float sum = 0.0f;
// 外层循环遍历 K 维度的所有 tile // 外层循环遍历 K 维度的所有 tile
...@@ -603,13 +611,13 @@ __global__ void gemv_bf16_TN_vec_warp_unroll_shm( ...@@ -603,13 +611,13 @@ __global__ void gemv_bf16_TN_vec_warp_unroll_shm(
// Lane 0 写回结果 // Lane 0 写回结果
if (lane_id == 0) { if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum); y[m] = hip_bfloat16(alpha * sum + beta * y_original);
} }
return; return;
} }
/** y = alpha * A^T * x + 0 * y /** y = alpha * A^T * x + beta * y
* 单线程 vec + warp 处理多行 + shmem 缓存 x: * 单线程 vec + warp 处理多行 + shmem 缓存 x:
* - JKI * - JKI
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。 * - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
...@@ -618,12 +626,11 @@ __global__ void gemv_bf16_TN_vec_warp_unroll_shm( ...@@ -618,12 +626,11 @@ __global__ void gemv_bf16_TN_vec_warp_unroll_shm(
* - shmem 缓存 x,分块加载。 * - shmem 缓存 x,分块加载。
*/ */
template <bool USE_NTL = false, int TILE_K = 4096, int ROWS_PER_WARP = 2> template <bool USE_NTL = false, int TILE_K = 4096, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha, __global__ void
const hip_bfloat16 *__restrict__ A, gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
int lda, const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x, const hip_bfloat16 *__restrict__ x,
const float beta, // 0 const float beta, hip_bfloat16 *__restrict__ y) {
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
...@@ -637,13 +644,16 @@ __global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha, ...@@ -637,13 +644,16 @@ __global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
// 每个 lane 维护 ROWS_PER_WARP 个累加器 // 每个 lane 维护 ROWS_PER_WARP 个累加器
float sum[ROWS_PER_WARP] = {0.0f}; float sum[ROWS_PER_WARP] = {0.0f};
// 预先计算每一行的指针 // 预先计算每一行的指针和原始 y 值
const hip_bfloat16 *row_ptr[ROWS_PER_WARP]; const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
float y_original[ROWS_PER_WARP];
#pragma unroll #pragma unroll
for (int r = 0; r < ROWS_PER_WARP; ++r) { for (int r = 0; r < ROWS_PER_WARP; ++r) {
int m = m_base + r; int m = m_base + r;
// 越界时指向 A,确保地址有效,消除后续分支 // 越界时指向 A,确保地址有效,消除后续分支
row_ptr[r] = (m < M) ? (A + m * lda) : A; row_ptr[r] = (m < M) ? (A + m * lda) : A;
// 读取有效的原始 y 值
y_original[r] = (m < M) ? static_cast<float>(y[m]) : 0.0f;
} }
// 外层循环遍历 K 维度的所有 tile // 外层循环遍历 K 维度的所有 tile
...@@ -720,7 +730,7 @@ __global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha, ...@@ -720,7 +730,7 @@ __global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
if (lane_id == 0) { if (lane_id == 0) {
int m = m_base + r; int m = m_base + r;
if (m < M) { if (m < M) {
y[m] = hip_bfloat16(alpha * sum[r]); y[m] = hip_bfloat16(alpha * sum[r] + beta * y_original[r]);
} }
} }
} }
......
...@@ -111,9 +111,6 @@ inline void run_benchmark(int warmups, int loops, ...@@ -111,9 +111,6 @@ inline void run_benchmark(int warmups, int loops,
float alpha, const hip_bfloat16 *A, int lda, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, float beta, hip_bfloat16 *y, const hip_bfloat16 *x, float beta, hip_bfloat16 *y,
bool do_verify) { bool do_verify) {
std::cout << "GEMV Benchmarks" << std::endl;
hipEvent_t start, stop; hipEvent_t start, stop;
checkHipErrors(hipEventCreate(&start)); checkHipErrors(hipEventCreate(&start));
checkHipErrors(hipEventCreate(&stop)); checkHipErrors(hipEventCreate(&stop));
...@@ -141,12 +138,6 @@ inline void run_benchmark(int warmups, int loops, ...@@ -141,12 +138,6 @@ inline void run_benchmark(int warmups, int loops,
// 表头 // 表头
printf("%s\n", std::string(w_table, '-').c_str()); printf("%s\n", std::string(w_table, '-').c_str());
printf("M=%d, K=%d, N=1\n", M, K);
printf("lda=%d\n", lda);
printf("sizeof(A)=%lu MB\n", M * lda * sizeof(hip_bfloat16) / 1024 / 1024);
printf("L2 cache=%d MB\n", get_l2_cache_size());
printf("Warmups=%d, Loops=%d\n", warmups, loops);
printf("%s\n", std::string(w_table, '-').c_str());
printf("%-38s %10s %10s %10s %8s\n", "Kernel Name", "Time (us)", "GFLOPS", printf("%-38s %10s %10s %10s %8s\n", "Kernel Name", "Time (us)", "GFLOPS",
"BW (GB/s)", "Result"); "BW (GB/s)", "Result");
...@@ -195,9 +186,9 @@ inline void run_benchmark(int warmups, int loops, ...@@ -195,9 +186,9 @@ inline void run_benchmark(int warmups, int loops,
// 4. Metrics // 4. Metrics
double gflops = (2.0 * M * K) / (avg_ms * 1e-3) / 1e9; double gflops = (2.0 * M * K) / (avg_ms * 1e-3) / 1e9;
// Bandwidth = Read A + Read x + Write y // Bandwidth = Read A + Read x + Read y + Write y
// A: M*K, x: K, y: M // A: M*K, x: K, y: M
double bytes_moved = (double)(M * K + K + M) * sizeof(hip_bfloat16); double bytes_moved = (double)(M * K + K + M + M) * sizeof(hip_bfloat16);
double bw = bytes_moved / (avg_ms * 1e-3) / 1e9; double bw = bytes_moved / (avg_ms * 1e-3) / 1e9;
printf("%-38s %10.1f %10.2f %10.2f %8s\n", k.name.c_str(), avg_ms * 1e3, printf("%-38s %10.1f %10.2f %10.2f %8s\n", k.name.c_str(), avg_ms * 1e3,
......
...@@ -7,10 +7,10 @@ int main(int argc, char **argv) { ...@@ -7,10 +7,10 @@ int main(int argc, char **argv) {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
int M = 11264; int M = 11264;
int K = 4096;
// int N = 1; // Unused // int N = 1; // Unused
int K = 4096;
int lda = K; int lda = K;
int block_size = 256; int block_size = 128;
if (char *value = getCmdOption(argv, argv + argc, "--warmups")) { if (char *value = getCmdOption(argv, argv + argc, "--warmups")) {
warmups = std::stoi(value); warmups = std::stoi(value);
...@@ -28,6 +28,10 @@ int main(int argc, char **argv) { ...@@ -28,6 +28,10 @@ int main(int argc, char **argv) {
alpha = std::stof(value); alpha = std::stof(value);
} }
if (char *value = getCmdOption(argv, argv + argc, "--beta")) {
beta = std::stof(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-M")) { if (char *value = getCmdOption(argv, argv + argc, "-M")) {
M = std::stoi(value); M = std::stoi(value);
} }
...@@ -82,7 +86,7 @@ int main(int argc, char **argv) { ...@@ -82,7 +86,7 @@ int main(int argc, char **argv) {
constexpr int UNROLL = 4; constexpr int UNROLL = 4;
constexpr int ROWS_PER_WARP = 2; constexpr int ROWS_PER_WARP = 2;
#if defined(__HIP_PLATFORM_AMD__) #if defined(__HIP_PLATFORM_AMD__)
constexpr int TILE_K = calculate_tile_k<8>(4); constexpr int TILE_K = calculate_tile_k<8>(8);
#else #else
constexpr int TILE_K = calculate_tile_k<8>(1); constexpr int TILE_K = calculate_tile_k<8>(1);
#endif #endif
...@@ -253,6 +257,15 @@ int main(int argc, char **argv) { ...@@ -253,6 +257,15 @@ int main(int argc, char **argv) {
<<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y); <<<grid, block_size>>>(M, K, alpha, A, lda, x, beta, y);
}}); }});
// 打印信息
printf("GEMV Benchmarks: y = alpha * A^T * x + beta * y\n");
printf("Block size=%d\n", block_size);
printf("alpha=%.2f, beta=%.2f\n", alpha, beta);
printf("M=%d, N=1, K=%d, lda=%d\n", M, K, lda);
printf("sizeof(A)=%lu MB\n", M * lda * sizeof(hip_bfloat16) / 1024 / 1024);
printf("L2 cache=%d MB\n", get_l2_cache_size());
printf("Warmups=%d, Loops=%d\n", warmups, loops);
// 运行所有测试 // 运行所有测试
run_benchmark(warmups, loops, kernels, M, K, alpha, d_A, lda, d_x, beta, d_y, run_benchmark(warmups, loops, kernels, M, K, alpha, d_A, lda, d_x, beta, d_y,
do_verify); do_verify);
......
...@@ -7,21 +7,27 @@ BIND_CMD="numactl -N 0 -m 0" ...@@ -7,21 +7,27 @@ BIND_CMD="numactl -N 0 -m 0"
make clean make clean
CXX=hipcc make GPU_ARCH=gfx936 CXX=hipcc make GPU_ARCH=gfx936
# CXX=nvcc make GPU_ARCH=sm_80 # CXX=nvcc make GPU_ARCH=sm_80
W1="--verify 1 -M 11264 -K 4096 --alpha 1 --beta 0 -B 128"
W2="--verify 1 -M 4096 -K 11264 --alpha 1 --beta 0 -B 128"
W3="--verify 1 -M 12288 -K 4096 --alpha 1 --beta 0 -B 128"
W4="--verify 1 -M 4096 -K 4096 --alpha 1 --beta 1 -B 128"
if [[ "$*" == *"--pmc"* ]]; then if [[ "$*" == *"--pmc"* ]]; then
PROF_CMD="hipprof --trace-off --pmc" PROF_CMD="hipprof --trace-off --pmc"
${PROF_CMD} -o log/pmc-w1 ${BIND_CMD} ./gemv_bench --warmups 10 --loops 20 --verify 1 -M 11264 -K 4096 ${PROF_CMD} -o log/pmc-w1 ${BIND_CMD} ./gemv_bench --warmups 10 --loops 20 ${W1}
${PROF_CMD} -o log/pmc-w2 ${BIND_CMD} ./gemv_bench --warmups 10 --loops 20 --verify 1 -M 4096 -K 11264 ${PROF_CMD} -o log/pmc-w2 ${BIND_CMD} ./gemv_bench --warmups 10 --loops 20 ${W2}
${PROF_CMD} -o log/pmc-w3 ${BIND_CMD} ./gemv_bench --warmups 10 --loops 20 --verify 1 -M 12288 -K 4096 ${PROF_CMD} -o log/pmc-w3 ${BIND_CMD} ./gemv_bench --warmups 10 --loops 20 ${W3}
${PROF_CMD} -o log/pmc-w4 ${BIND_CMD} ./gemv_bench --warmups 10 --loops 20 --verify 1 -M 4096 -K 4096 ${PROF_CMD} -o log/pmc-w4 ${BIND_CMD} ./gemv_bench --warmups 10 --loops 20 ${W4}
elif [[ "$*" == *"--trace"* ]]; then elif [[ "$*" == *"--trace"* ]]; then
PROF_CMD="hipprof --hip-trace" PROF_CMD="hipprof --hip-trace"
${PROF_CMD} -o log/trace-w1 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 1000 --verify 1 -M 11264 -K 4096 ${PROF_CMD} -o log/trace-w1 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 1000 ${W1}
${PROF_CMD} -o log/trace-w2 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 1000 --verify 1 -M 4096 -K 11264 ${PROF_CMD} -o log/trace-w2 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 1000 ${W2}
${PROF_CMD} -o log/trace-w3 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 1000 --verify 1 -M 12288 -K 4096 ${PROF_CMD} -o log/trace-w3 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 1000 ${W3}
${PROF_CMD} -o log/trace-w4 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 1000 --verify 1 -M 4096 -K 4096 ${PROF_CMD} -o log/trace-w4 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 1000 ${W4}
else else
${BIND_CMD} ./gemv_bench --warmups 100 --loops 2000 --verify 1 -M 11264 -K 4096 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 2000 ${W1}
${BIND_CMD} ./gemv_bench --warmups 100 --loops 2000 --verify 1 -M 4096 -K 11264 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 2000 ${W2}
${BIND_CMD} ./gemv_bench --warmups 100 --loops 2000 --verify 1 -M 12288 -K 4096 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 2000 ${W3}
${BIND_CMD} ./gemv_bench --warmups 100 --loops 2000 --verify 1 -M 4096 -K 4096 ${BIND_CMD} ./gemv_bench --warmups 100 --loops 2000 ${W4}
fi fi
#!/bin/bash
# Benchmark rocBLAS gemm_ex (bf16 in/out, f32 accumulate, A transposed) on
# four workloads, optionally under hipprof counter (--pmc) or trace (--trace)
# profiling. One log file per workload is written under log/.

# Make the bundled benchmark tool executable and reachable on PATH.
chmod u+x /opt/dtk/lib/rocblas/benchmark_tool/*
export PATH=/opt/dtk/lib/rocblas/benchmark_tool/:${PATH}

BIND_CMD="numactl -m 0 -N 0"   # bind CPU and memory to NUMA node 0
BATCH_SIZE=1                   # the gemm_ex "-n" dimension
export HIP_VISIBLE_DEVICES=1
# export ROCBLAS_TENSILE_GEMM_OVERRIDE_PATH=$(PWD)/tensil_gemms.csv

# Workload definitions W1..W4 (rocblas-bench gemm_ex argument strings).
W1="-f gemm_ex --transposeA T --transposeB N -m 11264 -n ${BATCH_SIZE} -k 4096 --alpha 1 --a_type bf16_r --lda 4096 --b_type bf16_r --ldb 4096 --beta 0 --c_type bf16_r --ldc 11264 --d_type bf16_r --ldd 11264 --compute_type f32_r --algo 0 --solution_index 0 --flags 0"
W2="-f gemm_ex --transposeA T --transposeB N -m 4096 -n ${BATCH_SIZE} -k 11264 --alpha 1 --a_type bf16_r --lda 11264 --b_type bf16_r --ldb 11264 --beta 0 --c_type bf16_r --ldc 4096 --d_type bf16_r --ldd 4096 --compute_type f32_r --algo 0 --solution_index 0 --flags 0"
W3="-f gemm_ex --transposeA T --transposeB N -m 12288 -n ${BATCH_SIZE} -k 4096 --alpha 1 --a_type bf16_r --lda 4096 --b_type bf16_r --ldb 4096 --beta 0 --c_type bf16_r --ldc 12288 --d_type bf16_r --ldd 12288 --compute_type f32_r --algo 0 --solution_index 0 --flags 0"
W4="-f gemm_ex --transposeA T --transposeB N -m 4096 -n ${BATCH_SIZE} -k 4096 --alpha 1 --a_type bf16_r --lda 4096 --b_type bf16_r --ldb 4096 --beta 1 --c_type bf16_r --ldc 4096 --d_type bf16_r --ldd 4096 --compute_type f32_r --algo 0 --solution_index 0 --flags 0"

if [[ "$*" == *"--pmc"* ]]; then
  # Performance-counter collection: one pmc log per workload.
  PROF_CMD="hipprof --trace-off --pmc"
  i=1
  for W in "${W1}" "${W2}" "${W3}" "${W4}"; do
    ${PROF_CMD} -o log/pmc-blas-w${i}-bs${BATCH_SIZE} ${BIND_CMD} rocblas-bench ${W}
    i=$((i + 1))
  done
elif [[ "$*" == *"--trace"* ]]; then
  # HIP API trace collection: one trace log per workload.
  PROF_CMD="hipprof --hip-trace"
  i=1
  for W in "${W1}" "${W2}" "${W3}" "${W4}"; do
    ${PROF_CMD} -o log/trace-blas-w${i}-bs${BATCH_SIZE} ${BIND_CMD} rocblas-bench ${W}
    i=$((i + 1))
  done
else
  # Plain benchmark runs without any profiler attached.
  for W in "${W1}" "${W2}" "${W3}" "${W4}"; do
    ${BIND_CMD} rocblas-bench ${W}
  done
fi
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment