#include "gemv_bf16.h"

extern "C" {

/**
 * y = alpha * A^T * x + beta * y
 *
 * Host-side launcher for the gemv_bf16_TN_vec_warp_unroll kernel, instantiated
 * with UNROLL = 4 and USE_NTL = true (USE_NTL presumably selects non-temporal
 * loads inside the kernel — confirm against the kernel implementation).
 *
 * Launch geometry: 128-thread blocks, i.e. (128 / WARP_SIZE) warps per block,
 * and enough blocks to provide at least M warps in total.
 *
 * The launch is asynchronous on the default stream; this function does not
 * synchronize or check for launch errors.
 *
 * @param d_A: input matrix A (device pointer)
 * @param M: number of rows of A
 * @param K: number of columns of A
 * @param lda: leading dimension of A
 * @param d_x: input vector x (device pointer)
 * @param d_y: output vector y (device pointer), updated in place
 * @param alpha: scaling factor for A^T * x
 * @param beta: scaling factor for y
 */
void gemv_bf16_TN_vec_warp_unroll_ntl(hip_bfloat16 *d_A, int M, int K, int lda,
                                      hip_bfloat16 *d_x, hip_bfloat16 *d_y,
                                      float alpha, float beta) {
    constexpr bool USE_NTL = true;
    constexpr int UNROLL = 4;
    const int block_size = 128;
    const int warps_per_block = block_size / WARP_SIZE;

    // Nothing to launch: a zero-sized grid is an invalid launch configuration.
    if (M <= 0) {
        return;
    }

    // ceil(M / warps_per_block), written so it cannot overflow int for large M
    // (the naive (M + warps_per_block - 1) form overflows near INT_MAX).
    const int grid = M / warps_per_block + (M % warps_per_block != 0 ? 1 : 0);

    dim3 grid_dim(grid);
    dim3 block_dim(block_size);

    // NOTE(review): template-argument order <UNROLL, USE_NTL> is inferred from
    // the launcher name ("..._unroll_ntl"); confirm it matches the kernel's
    // template declaration in gemv_bf16.h.
    gemv_bf16_TN_vec_warp_unroll<UNROLL, USE_NTL><<<grid_dim, block_dim>>>(
        M, K, alpha, d_A, lda, d_x, beta, d_y);
}

} // extern "C"