gemv_export.cpp 978 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include "gemv_bf16.h"

extern "C" {

/** y = alpha * A^T * x + beta * y   (bf16 GEMV, "TN" variant, C ABI export)
 *
 * Host-side launcher for gemv_bf16_TN_vec_warp_unroll<USE_NTL = true,
 * UNROLL = 4>. The grid is sized as ceil(M / warps_per_block) blocks of 128
 * threads, i.e. at least one warp per row index in [0, M).
 * NOTE(review): this suggests the kernel produces one output element per warp;
 * confirm against the kernel in gemv_bf16.h.
 *
 * The kernel is launched on the default (null) stream. This function neither
 * synchronizes nor checks for launch errors; callers should use
 * hipGetLastError() / hipDeviceSynchronize() as needed. If M <= 0 nothing is
 * launched.
 *
 * NOTE(review): which dimension lda strides over, and the exact lengths of x
 * and y, are defined by the kernel, not here. Presumably y has M entries and
 * x has K entries; verify before relying on it.
 *
 * @param d_A: input matrix A (passed straight to the kernel, so it must be
 *             device-accessible)
 * @param M: number of rows of A; also determines the launch grid size
 * @param K: number of columns of A
 * @param lda: leading dimension of A
 * @param d_x: input vector x (device-accessible)
 * @param d_y: output vector y (device-accessible); read when beta != 0
 * @param alpha: scaling factor for A^T * x
 * @param beta: scaling factor for y
 */
void gemv_bf16_TN_vec_warp_unroll_ntl(hip_bfloat16 *d_A, int M, int K, int lda,
                                      hip_bfloat16 *d_x, hip_bfloat16 *d_y,
                                      float alpha, float beta) {
  // Fixed kernel template configuration for this exported entry point.
  // NOTE(review): USE_NTL presumably selects non-temporal loads in the kernel.
  constexpr bool USE_NTL = true;
  constexpr int UNROLL = 4;

  // An empty (or negative) row count would produce a zero-sized or wrapped
  // grid. A zero-block launch is an invalid configuration and leaves a pending
  // error for the caller's next hipGetLastError(), so skip the launch entirely.
  if (M <= 0) {
    return;
  }

  constexpr int block_size = 128;
  const int warps_per_block = block_size / WARP_SIZE;
  // Ceiling division in a form that cannot overflow int for M near INT_MAX
  // (valid because M > 0 here).
  const int grid = (M - 1) / warps_per_block + 1;

  const dim3 grid_dim(grid);
  const dim3 block_dim(block_size);

  gemv_bf16_TN_vec_warp_unroll<USE_NTL, UNROLL>
      <<<grid_dim, block_dim>>>(M, K, alpha, d_A, lda, d_x, beta, d_y);
}

} // extern "C"