Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tsoc
hg-misc-tools
Commits
b3a56179
Commit
b3a56179
authored
Jan 28, 2026
by
one
Browse files
Update GEMV benchmarks
parent
977247a7
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
984 additions
and
936 deletions
+984
-936
evo2/bm/Makefile
evo2/bm/Makefile
+1
-1
evo2/bm/gemv_bf16.cpp
evo2/bm/gemv_bf16.cpp
+955
-0
evo2/bm/gemv_bf16.hip
evo2/bm/gemv_bf16.hip
+0
-916
evo2/bm/gemv_utils.h
evo2/bm/gemv_utils.h
+13
-12
evo2/bm/run-all.sh
evo2/bm/run-all.sh
+15
-7
No files found.
evo2/bm/Makefile
View file @
b3a56179
...
...
@@ -3,7 +3,7 @@ CXXFLAGS ?= -std=c++17 -O3
OFFLOAD_ARCH
?=
gfx936
TARGET
:=
gemv_bench
SRC
:=
gemv_bf16.
hi
p
SRC
:=
gemv_bf16.
cp
p
DEP
:=
gemv_utils.h
.PHONY
:
all clean
...
...
evo2/bm/gemv_bf16.cpp
0 → 100644
View file @
b3a56179
/**
 * GEMV that mimics a GEMM interface, i.e. N = 1.
 * Build:
 *   hipcc -std=c++17 -O3 --offload-arch=gfx936 gemv_bf16.cpp -o gemv_bench
 * Run:
 *   HIP_VISIBLE_DEVICES=4 numactl -N 0 -m 0 ./gemv_bench -M 11264 -K 4096
 */
#include "gemv_utils.h"
#define WARP_SIZE 64
#define VEC_WIDTH 8
#define OFFSET(i, j, lda) ((i) + (j) * (lda))
#define OFFSET_T(i, j, lda) ((i) * (lda) + (j))
/**
 * Computes the LDS budget (as a TILE_K element count in bf16 elements) for a
 * requested number of concurrently resident blocks per CU.
 *
 * AlignElements is the alignment granularity in elements; default is 128-bit.
 * - 8:  align to 128-bit (may help load128b)
 * - 16: align to 256-bit (required by some MFMA instructions)
 */
template <int AlignElements = 8>
constexpr int calculate_tile_k(int concurrent_blocks) {
  // Clamp to at least one resident block.
  const int blocks = (concurrent_blocks < 1) ? 1 : concurrent_blocks;
  // Split the per-CU LDS capacity evenly across the concurrent blocks.
  constexpr int MAX_LDS_BYTES_PER_CU = 65536;
  const int budget_bytes = MAX_LDS_BYTES_PER_CU / blocks;
  // Convert the byte budget to a bf16 element count.
  const int budget_elements =
      budget_bytes / static_cast<int>(sizeof(hip_bfloat16));
  // Round down to the requested alignment granularity.
  return budget_elements - (budget_elements % AlignElements);
}
/// Helper aggregate: reinterprets one 128-bit quantity (e.g. a float4) as
/// 8 bf16 values. 16-byte alignment matches the 128-bit load width.
struct __align__(16) bf16_x8 {
  hip_bfloat16 vals[VEC_WIDTH]; // VEC_WIDTH == 8 lanes of bf16
};
/// 替代 float4
typedef
float
__attribute__
((
ext_vector_type
(
4
)))
float4_native
;
/// 128-bit load of 8 bf16 values: non-temporal when USE_NTL, cached otherwise.
/// NOTE(review): assumes `src` is 16-byte aligned — required for the 128-bit
/// access; confirm with callers.
template <bool USE_NTL = false>
__device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
  if constexpr (USE_NTL) {
    // Reinterpret the address as a float4_native pointer.
    const float4_native *ptr = reinterpret_cast<const float4_native *>(src);
    // Clang builtin non-temporal load: emits the load with slc/nt modifiers.
    float4_native tmp = __builtin_nontemporal_load(ptr);
    // Reinterpret the loaded 128 bits as bf16_x8.
    return *reinterpret_cast<bf16_x8 *>(&tmp);
  } else {
    return *reinterpret_cast<const bf16_x8 *>(src);
  }
}
/** y = alpha * A^T * x + 0 * y
 * Naive implementation:
 * - JKI loop order
 * - one thread per output element, i.e. one iteration of the I loop
 */
__global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
                                   const hip_bfloat16 *__restrict__ A, int lda,
                                   const hip_bfloat16 *__restrict__ x,
                                   const float beta, // 0
                                   hip_bfloat16 *__restrict__ y) {
  int m = blockIdx.x * blockDim.x + threadIdx.x; // output
  if (m >= M)
    return;
  const hip_bfloat16 *row_ptr = A + m * lda; // row m of (row-major) A
  float sum = 0.0f;                          // accumulate in fp32
  for (int k = 0; k < K; k++) {
    float val_a = static_cast<float>(row_ptr[k]);
    float val_x = static_cast<float>(x[k]);
    sum += val_a * val_x;
  }
  y[m] = hip_bfloat16(alpha * sum); // beta is 0, so y is overwritten
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized implementation:
 * - JKI loop order
 * - one thread per output element, i.e. one iteration of the I loop
 * - each thread reads VEC_WIDTH bf16 values at a time
 *   (matrix A can use a non-temporal load)
 * NOTE(review): assumes K is a multiple of VEC_WIDTH; otherwise the last
 * vector load reads past the end of the row — confirm with callers.
 */
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
                                 const hip_bfloat16 *__restrict__ A, int lda,
                                 const hip_bfloat16 *__restrict__ x,
                                 const float beta, // 0
                                 hip_bfloat16 *__restrict__ y) {
  int m = blockIdx.x * blockDim.x + threadIdx.x; // output
  if (m >= M)
    return;
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Read VEC_WIDTH elements per iteration.
  for (int k = 0; k < K; k += VEC_WIDTH) {
    bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
    for (int i = 0; i < VEC_WIDTH; ++i) {
      sum +=
          static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
    }
  }
  y[m] = hip_bfloat16(alpha * sum);
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Warp reduction:
 * - JKI loop order
 * - one warp per output, i.e. tiling along K with the warp size as stride
 * - intra-warp reduction
 */
__global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
                                  const hip_bfloat16 *__restrict__ A, int lda,
                                  const hip_bfloat16 *__restrict__ x,
                                  const float beta, // 0
                                  hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  if (m >= M)
    return;
  const int stride = WARP_SIZE;
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Lanes stride through K; each lane accumulates a partial dot product.
  for (int k = lane_id; k < K; k += stride) {
    float val_a = static_cast<float>(row_ptr[k]);
    float val_x = static_cast<float>(x[k]);
    sum += val_a * val_x;
  }
  // Butterfly reduction across the warp via shuffles; lane 0 ends with the total.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vec + warp:
 * - JKI loop order
 * - each thread reads VEC_WIDTH bf16 values at a time
 *   (matrix A can use a non-temporal load)
 * - one warp per output, intra-warp reduction
 * NOTE(review): assumes K is a multiple of WARP_SIZE * VEC_WIDTH per the
 * unguarded vector loads — confirm with callers.
 */
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
                                      const hip_bfloat16 *__restrict__ A,
                                      int lda,
                                      const hip_bfloat16 *__restrict__ x,
                                      const float beta, // 0
                                      hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  if (m >= M)
    return;
  const int stride = WARP_SIZE * VEC_WIDTH;
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Each lane handles one VEC_WIDTH-wide chunk per stride step.
  for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
    bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
    for (int i = 0; i < VEC_WIDTH; ++i) {
      sum +=
          static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
    }
  }
  // Intra-warp shuffle reduction; lane 0 ends with the total.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vec + warp, multiple rows per warp:
 * - JKI loop order
 * - each thread reads VEC_WIDTH bf16 values at a time
 *   (matrix A can use a non-temporal load)
 * - each warp handles ROWS_PER_WARP output rows; intra-warp reduction
 *   (each row reduced independently)
 * - each lane keeps ROWS_PER_WARP accumulators
 */
template <bool USE_NTL = false, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
                                         const hip_bfloat16 *__restrict__ A,
                                         int lda,
                                         const hip_bfloat16 *__restrict__ x,
                                         const float beta, // 0
                                         hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  // Each warp handles ROWS_PER_WARP rows.
  int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
               warp_id * ROWS_PER_WARP;
  // Each lane keeps ROWS_PER_WARP accumulators.
  float sum[ROWS_PER_WARP] = {0.0f};
  // Precompute each row's base pointer.
  const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
    int m = m_base + r;
    // Out-of-range rows point at A so the address stays valid; this removes
    // bounds branches from the inner loop (their results are discarded below).
    row_ptr[r] = (m < M) ? (A + m * lda) : A;
  }
  const int stride = WARP_SIZE * VEC_WIDTH;
  for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
    // x is shared across the ROWS_PER_WARP rows: load it once per chunk.
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
    bf16_x8 a_vecs[ROWS_PER_WARP];
    // Batched loads, no branches.
#pragma unroll
    for (int r = 0; r < ROWS_PER_WARP; ++r) {
      a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][k]);
    }
    // Batched FMAs.
#pragma unroll
    for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
      for (int i = 0; i < VEC_WIDTH; ++i) {
        sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
                  static_cast<float>(x_vec.vals[i]);
      }
    }
  }
  // Intra-warp reduction (each row reduced independently).
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      sum[r] += __shfl_down(sum[r], offset);
    }
    // Lane 0 writes the result back (bounds-checked here, not in the loop).
    if (lane_id == 0) {
      int m = m_base + r;
      if (m < M) {
        y[m] = hip_bfloat16(alpha * sum[r]);
      }
    }
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vec + warp + main-loop unroll:
 * - JKI loop order
 * - each thread reads VEC_WIDTH bf16 values at a time
 *   (matrix A can use a non-temporal load)
 * - one warp per output, intra-warp reduction
 * - main loop unrolled by UNROLL
 */
template <bool USE_NTL = false, int UNROLL = 4>
__global__ void
gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
                             const hip_bfloat16 *__restrict__ A, int lda,
                             const hip_bfloat16 *__restrict__ x,
                             const float beta, // 0
                             hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  if (m >= M)
    return;
  // One full unrolled step covers this many elements of K.
  const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Main-loop staging registers.
  bf16_x8 a_frag[UNROLL];
  bf16_x8 x_frag[UNROLL];
  int k0 = lane_id * VEC_WIDTH; // this lane's offset within a warp chunk
  int k = 0;
  // Main loop: issue all UNROLL loads first, then all the FMAs.
  for (; k <= K - stride; k += stride) {
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
      int offset = k + k0 + u * (WARP_SIZE * VEC_WIDTH);
      a_frag[u] = load_128b<USE_NTL>(&row_ptr[offset]);
      x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
    }
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
      for (int i = 0; i < VEC_WIDTH; ++i) {
        sum += static_cast<float>(a_frag[u].vals[i]) *
               static_cast<float>(x_frag[u].vals[i]);
      }
    }
  }
  // Tail loop: one warp chunk at a time.
  // NOTE(review): still loads a full VEC_WIDTH vector — assumes K is a
  // multiple of VEC_WIDTH; confirm with callers.
  for (; k < K; k += WARP_SIZE * VEC_WIDTH) {
    int offset = k + k0;
    if (offset >= K)
      continue;
    bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[offset]);
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
    for (int i = 0; i < VEC_WIDTH; ++i) {
      sum +=
          static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
    }
  }
  // Intra-warp shuffle reduction; lane 0 ends with the total.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vec + warp + x cached in shared memory:
 * - JKI loop order
 * - each thread reads VEC_WIDTH bf16 values at a time
 *   (matrix A can use a non-temporal load)
 * - one warp per output, intra-warp reduction
 * - x cached in LDS, loaded tile by tile (TILE_K bf16 elements per tile)
 */
template <bool USE_NTL = false, int TILE_K = 4096>
__global__ void
gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
                          const hip_bfloat16 *__restrict__ A, int lda,
                          const hip_bfloat16 *__restrict__ x,
                          const float beta, // 0
                          hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  // LDS cache for one tile of x.
  __shared__ hip_bfloat16 x_tile[TILE_K];
  // A is never dereferenced when m >= M (see the guard below), so computing
  // the address unconditionally needs no branch.
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Outer loop over all K tiles. Out-of-range warps must still run it to hit
  // the __syncthreads() barriers below.
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);
    // Step 1: all threads cooperatively load the current x tile into LDS,
    // VEC_WIDTH elements per thread per step.
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized copy.
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Tail: element-wise copy.
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();
    // Step 2: accumulate this tile's contribution (valid warps only).
    if (m < M) {
      const int stride = WARP_SIZE * VEC_WIDTH;
      for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
        if (k + VEC_WIDTH <= tile_size) {
          // Full vectorized step.
          bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
          bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_vec.vals[i]) *
                   static_cast<float>(x_vec.vals[i]);
          }
        } else {
          // Tail: scalar accumulation.
          for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
            float val_a = static_cast<float>(row_ptr[kk + k + i]);
            float val_x = static_cast<float>(x_tile[k + i]);
            sum += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }
  if (m >= M)
    return;
  // Intra-warp reduction.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vec + warp + main-loop unroll + x cached in shared memory:
 * - JKI loop order
 * - each thread reads VEC_WIDTH bf16 values at a time
 *   (matrix A can use a non-temporal load)
 * - one warp per output, intra-warp reduction
 * - main loop unrolled by UNROLL
 * - x cached in LDS, loaded tile by tile (TILE_K bf16 elements per tile)
 */
template <bool USE_NTL = false, int UNROLL = 4, int TILE_K = 4096>
__global__ void
gemv_bf16_TN_vec_warp_unroll_shm(int M, int K, const float alpha,
                                 const hip_bfloat16 *__restrict__ A, int lda,
                                 const hip_bfloat16 *__restrict__ x,
                                 const float beta, // 0
                                 hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  // LDS cache for one tile of x.
  __shared__ hip_bfloat16 x_tile[TILE_K];
  // A is never dereferenced when m >= M (see the guard below), so computing
  // the address unconditionally needs no branch.
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Outer loop over all K tiles. Out-of-range warps must still run it to hit
  // the __syncthreads() barriers below.
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);
    // Step 1: all threads cooperatively load the current x tile into LDS,
    // VEC_WIDTH elements per thread per step.
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized copy.
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Tail: element-wise copy.
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();
    // Step 2: accumulate this tile's contribution (valid warps only).
    if (m < M) {
      const int warp_stride = WARP_SIZE * VEC_WIDTH;
      const int unroll_stride = warp_stride * UNROLL;
      int k = lane_id * VEC_WIDTH;
      // Main loop: unrolled — all UNROLL loads first, then all the FMAs.
      for (; k <= tile_size - unroll_stride; k += unroll_stride) {
        bf16_x8 a_frag[UNROLL];
        bf16_x8 x_frag[UNROLL];
#pragma unroll
        for (int u = 0; u < UNROLL; ++u) {
          int current_k = k + u * warp_stride;
          a_frag[u] = load_128b<USE_NTL>(&row_ptr[kk + current_k]);
          x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x_tile[current_k]);
        }
#pragma unroll
        for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_frag[u].vals[i]) *
                   static_cast<float>(x_frag[u].vals[i]);
          }
        }
      }
      // Tail loop: one warp chunk at a time.
      for (; k < tile_size; k += warp_stride) {
        if (k + VEC_WIDTH <= tile_size) {
          // Full vectorized step.
          bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
          bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_vec.vals[i]) *
                   static_cast<float>(x_vec.vals[i]);
          }
        } else {
          // Tail: scalar accumulation.
          for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
            float val_a = static_cast<float>(row_ptr[kk + k + i]);
            float val_x = static_cast<float>(x_tile[k + i]);
            sum += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }
  if (m >= M)
    return;
  // Intra-warp reduction.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back.
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vec + warp, multiple rows per warp + x cached in shared memory:
 * - JKI loop order
 * - each thread reads VEC_WIDTH bf16 values at a time
 *   (matrix A can use a non-temporal load)
 * - each warp handles ROWS_PER_WARP output rows; intra-warp reduction
 *   (each row reduced independently)
 * - each lane keeps ROWS_PER_WARP accumulators
 * - x cached in LDS, loaded tile by tile (TILE_K bf16 elements per tile)
 */
template <bool USE_NTL = false, int TILE_K = 4096, int ROWS_PER_WARP = 2>
__global__ void
gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
                             const hip_bfloat16 *__restrict__ A, int lda,
                             const hip_bfloat16 *__restrict__ x,
                             const float beta, // 0
                             hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  // Each warp handles ROWS_PER_WARP rows.
  int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
               warp_id * ROWS_PER_WARP;
  // LDS cache for one tile of x.
  __shared__ hip_bfloat16 x_tile[TILE_K];
  // Each lane keeps ROWS_PER_WARP accumulators.
  float sum[ROWS_PER_WARP] = {0.0f};
  // Precompute each row's base pointer.
  const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
    int m = m_base + r;
    // Out-of-range rows point at A so the address stays valid; this removes
    // bounds branches from the inner loop (their results are discarded below).
    row_ptr[r] = (m < M) ? (A + m * lda) : A;
  }
  // Outer loop over all K tiles.
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);
    // Step 1: all threads cooperatively load the current x tile into LDS.
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized copy.
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Tail: element-wise copy.
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();
    // Step 2: accumulate this tile's contribution;
    // each lane works on ROWS_PER_WARP rows.
    const int stride = WARP_SIZE * VEC_WIDTH;
    for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
      if (k + VEC_WIDTH <= tile_size) {
        // Full vectorized step: x loaded once, shared across the rows.
        bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
        bf16_x8 a_vecs[ROWS_PER_WARP];
        // Batched loads, no branches.
#pragma unroll
        for (int r = 0; r < ROWS_PER_WARP; ++r) {
          a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][kk + k]);
        }
        // Batched FMAs.
#pragma unroll
        for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
                      static_cast<float>(x_vec.vals[i]);
          }
        }
      } else {
        // Tail: scalar accumulation per row.
        for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
          float val_x = static_cast<float>(x_tile[k + i]);
#pragma unroll
          for (int r = 0; r < ROWS_PER_WARP; ++r) {
            float val_a = static_cast<float>(row_ptr[r][kk + k + i]);
            sum[r] += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }
  // Intra-warp reduction (each row reduced independently).
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      sum[r] += __shfl_down(sum[r], offset);
    }
    // Lane 0 writes the result back (bounds-checked here).
    if (lane_id == 0) {
      int m = m_base + r;
      if (m < M) {
        y[m] = hip_bfloat16(alpha * sum[r]);
      }
    }
  }
  return;
}
/// GEMV Microbenchmarks
/// y = alpha * A^T * x + beta * y
/// M = 输出维度 (11264)
/// K = 归约维度 (4096)
/// N = 1
int
main
(
int
argc
,
char
**
argv
)
{
bool
do_verify
=
false
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
int
M
=
11264
;
int
K
=
4096
;
// int N = 1; // Unused
int
lda
=
K
;
int
block_size
=
256
;
if
(
char
*
value
=
getCmdOption
(
argv
,
argv
+
argc
,
"--verify"
))
{
do_verify
=
std
::
stoi
(
value
)
==
1
;
}
if
(
char
*
value
=
getCmdOption
(
argv
,
argv
+
argc
,
"--alpha"
))
{
alpha
=
std
::
stof
(
value
);
}
if
(
char
*
value
=
getCmdOption
(
argv
,
argv
+
argc
,
"-M"
))
{
M
=
std
::
stoi
(
value
);
}
if
(
char
*
value
=
getCmdOption
(
argv
,
argv
+
argc
,
"-K"
))
{
K
=
std
::
stoi
(
value
);
lda
=
K
;
}
if
(
char
*
value
=
getCmdOption
(
argv
,
argv
+
argc
,
"--lda"
))
{
lda
=
std
::
stoi
(
value
);
}
if
(
char
*
value
=
getCmdOption
(
argv
,
argv
+
argc
,
"-B"
))
{
block_size
=
std
::
stoi
(
value
);
}
// transA=T,因此是行优先
size_t
count_A
=
(
size_t
)
M
*
lda
;
size_t
size_A
=
count_A
*
sizeof
(
hip_bfloat16
);
size_t
size_x
=
(
size_t
)
K
*
sizeof
(
hip_bfloat16
);
size_t
size_y
=
(
size_t
)
M
*
sizeof
(
hip_bfloat16
);
// Host 内存分配
std
::
vector
<
hip_bfloat16
>
h_A
(
count_A
);
std
::
vector
<
hip_bfloat16
>
h_x
(
K
);
std
::
vector
<
hip_bfloat16
>
h_y
(
M
);
// 随机初始数据
const
float
rand_max
=
static_cast
<
float
>
(
RAND_MAX
);
for
(
int
i
=
0
;
i
<
count_A
;
i
++
)
h_A
[
i
]
=
hip_bfloat16
(
static_cast
<
float
>
(
rand
())
/
rand_max
);
for
(
int
i
=
0
;
i
<
K
;
i
++
)
h_x
[
i
]
=
hip_bfloat16
(
static_cast
<
float
>
(
rand
())
/
rand_max
);
for
(
int
i
=
0
;
i
<
M
;
i
++
)
h_y
[
i
]
=
hip_bfloat16
(
0.0
f
);
// Device 内存分配
hip_bfloat16
*
d_A
,
*
d_x
,
*
d_y
;
checkHipErrors
(
hipMalloc
(
&
d_A
,
size_A
));
checkHipErrors
(
hipMalloc
(
&
d_x
,
size_x
));
checkHipErrors
(
hipMalloc
(
&
d_y
,
size_y
));
checkHipErrors
(
hipMemcpy
(
d_A
,
h_A
.
data
(),
size_A
,
hipMemcpyHostToDevice
));
checkHipErrors
(
hipMemcpy
(
d_x
,
h_x
.
data
(),
size_x
,
hipMemcpyHostToDevice
));
checkHipErrors
(
hipMemcpy
(
d_y
,
h_y
.
data
(),
size_y
,
hipMemcpyHostToDevice
));
// Kernel 注册表
std
::
vector
<
KernelCase
>
kernels
;
constexpr
bool
NTL
=
true
;
constexpr
int
UNROLL
=
4
;
constexpr
int
TILE_K
=
calculate_tile_k
<
8
>
(
4
);
constexpr
int
ROWS_PER_WARP
=
2
;
kernels
.
push_back
(
{
"naive"
,
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
grid
=
(
M
+
block_size
-
1
)
/
block_size
;
gemv_bf16_TN_naive
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8"
,
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
grid
=
(
M
+
block_size
-
1
)
/
block_size
;
gemv_bf16_TN_vec
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8_ntl"
,
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
grid
=
(
M
+
block_size
-
1
)
/
block_size
;
gemv_bf16_TN_vec
<
NTL
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"warp"
,
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_warp
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8+warp"
,
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8_ntl+warp"
,
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp
<
NTL
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8+warp_mr"
+
std
::
to_string
(
ROWS_PER_WARP
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
((
M
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_mr
<!
NTL
,
ROWS_PER_WARP
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8_ntl+warp_mr"
+
std
::
to_string
(
ROWS_PER_WARP
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
((
M
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_mr
<
NTL
,
ROWS_PER_WARP
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8+warp+unroll"
+
std
::
to_string
(
UNROLL
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_unroll
<!
NTL
,
UNROLL
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8_ntl+warp+unroll"
+
std
::
to_string
(
UNROLL
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_unroll
<
NTL
,
UNROLL
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8+warp+shm"
+
std
::
to_string
(
TILE_K
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_shm
<!
NTL
,
TILE_K
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8_ntl+warp+shm"
+
std
::
to_string
(
TILE_K
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_shm
<
NTL
,
TILE_K
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8+warp+unroll"
+
std
::
to_string
(
UNROLL
)
+
"+shm"
+
std
::
to_string
(
TILE_K
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_unroll_shm
<!
NTL
,
UNROLL
,
TILE_K
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8_ntl+warp+unroll"
+
std
::
to_string
(
UNROLL
)
+
"+shm"
+
std
::
to_string
(
TILE_K
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
(
M
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_unroll_shm
<
NTL
,
UNROLL
,
TILE_K
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8+warp_mr"
+
std
::
to_string
(
ROWS_PER_WARP
)
+
"+shm"
+
std
::
to_string
(
TILE_K
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
((
M
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_mr_shm
<!
NTL
,
TILE_K
,
ROWS_PER_WARP
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
kernels
.
push_back
(
{
"vec8_ntl+warp_mr"
+
std
::
to_string
(
ROWS_PER_WARP
)
+
"+shm"
+
std
::
to_string
(
TILE_K
),
[
&
](
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
{
int
warps_per_block
=
block_size
/
WARP_SIZE
;
int
grid
=
((
M
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
+
warps_per_block
-
1
)
/
warps_per_block
;
gemv_bf16_TN_vec_warp_mr_shm
<
NTL
,
TILE_K
,
ROWS_PER_WARP
>
<<<
grid
,
block_size
>>>
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}});
// 运行所有测试
run_benchmark
(
kernels
,
M
,
K
,
alpha
,
d_A
,
lda
,
d_x
,
beta
,
d_y
,
do_verify
);
// 清理
checkHipErrors
(
hipFree
(
d_A
));
checkHipErrors
(
hipFree
(
d_x
));
checkHipErrors
(
hipFree
(
d_y
));
return
0
;
}
\ No newline at end of file
evo2/bm/gemv_bf16.hip
deleted
100644 → 0
View file @
977247a7
/**
* 模仿 GEMM 接口的 GEMV,即 N=1。
* 编译:
* hipcc -std=c++17 -O3 --offload-arch=gfx936 gemv_bf16.hip -o gemv_bench
* 执行:
* HIP_VISIBLE_DEVICES=4 numactl -N 0 -m 0 ./gemv_bench -M 11264 -K 4096
*/
#include "gemv_utils.h"
#define WARP_SIZE 64
#define VEC_WIDTH 8
#define OFFSET(i, j, lda) ((i) + (j) * (lda))
#define OFFSET_T(i, j, lda) ((i) * (lda) + (j))
// 辅助结构体:把 float4 (128位) 重新解释为 8 个 bf16
struct __align__(16) bf16_x8 {
hip_bfloat16 vals[VEC_WIDTH];
};
// 替代 float4
typedef float __attribute__((ext_vector_type(4))) float4_native;
// 128-bit non-temporal load 或者 cached load
template <bool USE_NTL = false>
__device__ __forceinline__ bf16_x8 load_128b(const hip_bfloat16 *src) {
if constexpr (USE_NTL) {
// 把地址转换为 float4_native 指针
const float4_native *ptr = reinterpret_cast<const float4_native *>(src);
// 使用 Clang 内置 non-temporal load 函数,生成带有 slc/nt 修饰符的加载指令
float4_native tmp = __builtin_nontemporal_load(ptr);
// 把加载到的 128 位数据重新解释为 bf16_x8
return *reinterpret_cast<bf16_x8 *>(&tmp);
} else {
return *reinterpret_cast<const bf16_x8 *>(src);
}
}
/** y = alpha * A^T * x + 0 * y
* Naive 实现:
* - JKI
* - 每个线程算一个输出,即 I 循环的一次迭代
*/
__global__ void gemv_bf16_TN_naive(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
// const float beta, // set to 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = 0; k < K; k++) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
* 向量化实现:
* - JKI
* - 每个线程算一个输出,即 I 循环的一次迭代。
* - 每个线程每次读 VEC_WIDTH 个 bf16 数据(矩阵 A 可用 non-temporal load)。
*/
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
// const float beta, // set to 0
hip_bfloat16 *__restrict__ y) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // output
if (m >= M)
return;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
// 每次读 VEC_WIDTH 个数据
for (int k = 0; k < K; k += VEC_WIDTH) {
bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[k]);
bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
for (int i = 0; i < VEC_WIDTH; ++i) {
sum +=
static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
}
}
y[m] = hip_bfloat16(alpha * sum);
return;
}
/** y = alpha * A^T * x + 0 * y
* Warp 归约:
* - JKI
* - 每个 warp 算一个输出,相当于用 warp size 作为 stride 沿着 K 方向 tiling。
* - Warp 内归约。
*/
__global__ void gemv_bf16_TN_warp(int M, int K, const float alpha,
const hip_bfloat16 *__restrict__ A, int lda,
const hip_bfloat16 *__restrict__ x,
// const float beta, // set to 0
hip_bfloat16 *__restrict__ y) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
if (m >= M)
return;
const int stride = WARP_SIZE;
const hip_bfloat16 *row_ptr = A + m * lda;
float sum = 0.0f;
for (int k = lane_id; k < K; k += stride) {
float val_a = static_cast<float>(row_ptr[k]);
float val_x = static_cast<float>(x[k]);
sum += val_a * val_x;
}
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
sum += __shfl_down(sum, offset);
}
// Lane 0 负责写回
if (lane_id == 0) {
y[m] = hip_bfloat16(alpha * sum);
}
return;
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized + warp-reduction variant:
 * - JKI order.
 * - Each thread loads VEC_WIDTH bf16 elements per step (matrix A may use
 *   non-temporal loads when USE_NTL is set).
 * - One warp per output row; in-warp reduction at the end.
 */
template <bool USE_NTL = false>
__global__ void gemv_bf16_TN_vec_warp(int M, int K, const float alpha,
                                      const hip_bfloat16 *__restrict__ A,
                                      int lda,
                                      const hip_bfloat16 *__restrict__ x,
                                      // const float beta, // set to 0
                                      hip_bfloat16 *__restrict__ y) {
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;
  const int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  if (m >= M)
    return;
  const hip_bfloat16 *a_row = A + m * lda;
  float acc = 0.0f;
  // Lanes cover K in interleaved chunks of VEC_WIDTH elements.
  for (int k = lane_id * VEC_WIDTH; k < K; k += WARP_SIZE * VEC_WIDTH) {
    const bf16_x8 a8 = load_128b<USE_NTL>(&a_row[k]);
    const bf16_x8 x8 = *reinterpret_cast<const bf16_x8 *>(&x[k]);
#pragma unroll
    for (int i = 0; i < VEC_WIDTH; ++i)
      acc += static_cast<float>(a8.vals[i]) * static_cast<float>(x8.vals[i]);
  }
  // Tree reduction across the warp.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    acc += __shfl_down(acc, offset);
  // Lane 0 writes the result back.
  if (lane_id == 0)
    y[m] = hip_bfloat16(alpha * acc);
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized + warp reduction, multiple rows per warp:
 * - JKI order.
 * - Each thread loads VEC_WIDTH bf16 elements at a time (matrix A may use
 *   non-temporal loads when USE_NTL is set).
 * - Each warp produces ROWS_PER_WARP output rows; the in-warp reduction is
 *   done independently per row.
 * - Each lane keeps ROWS_PER_WARP accumulators.
 * NOTE(review): the vector loads assume K is a multiple of VEC_WIDTH --
 * confirm at call sites.
 */
template <bool USE_NTL = false, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr(int M, int K, const float alpha,
                                         const hip_bfloat16 *__restrict__ A,
                                         int lda,
                                         const hip_bfloat16 *__restrict__ x,
                                         hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  // Each warp handles ROWS_PER_WARP consecutive rows.
  int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
               warp_id * ROWS_PER_WARP;
  // Each lane keeps ROWS_PER_WARP accumulators.
  float sum[ROWS_PER_WARP];
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
    sum[r] = 0.0f;
  }
  // Precompute one row pointer per handled row.
  const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
    int m = m_base + r;
    // Out-of-range rows alias A so the address stays valid and the later
    // loads need no branch; their results are discarded at write-back.
    row_ptr[r] = (m < M) ? (A + m * lda) : A;
  }
  const int stride = WARP_SIZE * VEC_WIDTH;
  for (int k = lane_id * VEC_WIDTH; k < K; k += stride) {
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[k]);
    bf16_x8 a_vecs[ROWS_PER_WARP];
    // Batched loads, branch-free
#pragma unroll
    for (int r = 0; r < ROWS_PER_WARP; ++r) {
      a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][k]);
    }
    // Batched accumulation
#pragma unroll
    for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
      for (int i = 0; i < VEC_WIDTH; ++i) {
        sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
                  static_cast<float>(x_vec.vals[i]);
      }
    }
  }
  // In-warp reduction (each row reduced independently)
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      sum[r] += __shfl_down(sum[r], offset);
    }
    // Lane 0 writes the result back
    if (lane_id == 0) {
      int m = m_base + r;
      if (m < M) {
        y[m] = hip_bfloat16(alpha * sum[r]);
      }
    }
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized + warp reduction + main-loop unrolling:
 * - JKI order.
 * - Each thread loads VEC_WIDTH bf16 elements at a time (matrix A may use
 *   non-temporal loads when USE_NTL is set).
 * - One warp computes one output row; in-warp reduction at the end.
 * - Main loop unrolled by a factor of UNROLL.
 * NOTE(review): the per-step vector loads assume K is a multiple of
 * VEC_WIDTH -- confirm at call sites.
 */
template <bool USE_NTL = false, int UNROLL = 4>
__global__ void gemv_bf16_TN_vec_warp_unroll(int M, int K, const float alpha,
                                             const hip_bfloat16 *__restrict__ A,
                                             int lda,
                                             const hip_bfloat16 *__restrict__ x,
                                             // const float beta, // set to 0
                                             hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  if (m >= M)
    return;
  // One fully-unrolled step covers this many elements of K.
  const int stride = WARP_SIZE * VEC_WIDTH * UNROLL;
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Temporaries for the unrolled main loop
  bf16_x8 a_frag[UNROLL];
  bf16_x8 x_frag[UNROLL];
  int k0 = lane_id * VEC_WIDTH; // this lane's offset within a warp step
  int k = 0;
  // Main loop: UNROLL load+accumulate steps per iteration
  for (; k <= K - stride; k += stride) {
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
      int offset = k + k0 + u * (WARP_SIZE * VEC_WIDTH);
      a_frag[u] = load_128b<USE_NTL>(&row_ptr[offset]);
      x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
    }
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
      for (int i = 0; i < VEC_WIDTH; ++i) {
        sum += static_cast<float>(a_frag[u].vals[i]) *
               static_cast<float>(x_frag[u].vals[i]);
      }
    }
  }
  // Tail loop: remaining (< stride) elements, one warp step at a time;
  // lanes whose offset falls past K skip the iteration.
  for (; k < K; k += WARP_SIZE * VEC_WIDTH) {
    int offset = k + k0;
    if (offset >= K)
      continue;
    bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[offset]);
    bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x[offset]);
    for (int i = 0; i < VEC_WIDTH; ++i) {
      sum +=
          static_cast<float>(a_vec.vals[i]) * static_cast<float>(x_vec.vals[i]);
    }
  }
  // In-warp tree reduction
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized + warp reduction + shared-memory cache for x:
 * - JKI order.
 * - Each thread loads VEC_WIDTH bf16 elements at a time (matrix A may use
 *   non-temporal loads when USE_NTL is set).
 * - One warp computes one output row; in-warp reduction at the end.
 * - x is staged tile-by-tile in shared memory (LDS).
 */
template <bool USE_NTL = false, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_shm(int M, int K, const float alpha,
                                          const hip_bfloat16 *__restrict__ A,
                                          int lda,
                                          const hip_bfloat16 *__restrict__ x,
                                          hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  // Shared-memory staging buffer for one tile of x
  __shared__ hip_bfloat16 x_tile[TILE_K];
  // A is never dereferenced when m >= M (guarded below), so no branch is
  // needed around this pointer computation.
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Outer loop over all tiles of the K dimension
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);
    // Step 1: all threads cooperatively load the current tile of x into LDS,
    // VEC_WIDTH elements per thread per iteration.
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized load
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Tail: load element by element
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();
    // Step 2: accumulate this tile's contribution (only in-range warps work)
    if (m < M) {
      const int stride = WARP_SIZE * VEC_WIDTH;
      for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
        if (k + VEC_WIDTH <= tile_size) {
          // Full vectorized step
          bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
          bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_vec.vals[i]) *
                   static_cast<float>(x_vec.vals[i]);
          }
        } else {
          // Tail: scalar accumulation
          for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
            float val_a = static_cast<float>(row_ptr[kk + k + i]);
            float val_x = static_cast<float>(x_tile[k + i]);
            sum += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }
  if (m >= M)
    return;
  // In-warp tree reduction
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized + warp reduction + main-loop unrolling + shared-memory cache:
 * - JKI order.
 * - Each thread loads VEC_WIDTH bf16 elements at a time (matrix A may use
 *   non-temporal loads when USE_NTL is set).
 * - One warp computes one output row; in-warp reduction at the end.
 * - Main loop unrolled by a factor of UNROLL.
 * - x is staged tile-by-tile in shared memory (LDS).
 */
template <bool USE_NTL = false, int UNROLL = 4, int TILE_K = 4096>
__global__ void gemv_bf16_TN_vec_warp_unroll_shm(
    int M, int K, const float alpha, const hip_bfloat16 *__restrict__ A,
    int lda, const hip_bfloat16 *__restrict__ x, hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int m = blockIdx.x * (blockDim.x / WARP_SIZE) + warp_id;
  // Shared-memory staging buffer for one tile of x
  __shared__ hip_bfloat16 x_tile[TILE_K];
  // A is never dereferenced when m >= M (guarded below), so no branch is
  // needed around this pointer computation.
  const hip_bfloat16 *row_ptr = A + m * lda;
  float sum = 0.0f;
  // Outer loop over all tiles of the K dimension
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);
    // Step 1: all threads cooperatively load the current tile of x into LDS,
    // VEC_WIDTH elements per thread per iteration.
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized load
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Tail: load element by element
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();
    // Step 2: accumulate this tile's contribution (only in-range warps work)
    if (m < M) {
      const int warp_stride = WARP_SIZE * VEC_WIDTH;
      const int unroll_stride = warp_stride * UNROLL;
      int k = lane_id * VEC_WIDTH;
      // Main loop: unrolled by UNROLL
      for (; k <= tile_size - unroll_stride; k += unroll_stride) {
        bf16_x8 a_frag[UNROLL];
        bf16_x8 x_frag[UNROLL];
#pragma unroll
        for (int u = 0; u < UNROLL; ++u) {
          int current_k = k + u * warp_stride;
          a_frag[u] = load_128b<USE_NTL>(&row_ptr[kk + current_k]);
          x_frag[u] = *reinterpret_cast<const bf16_x8 *>(&x_tile[current_k]);
        }
#pragma unroll
        for (int u = 0; u < UNROLL; ++u) {
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_frag[u].vals[i]) *
                   static_cast<float>(x_frag[u].vals[i]);
          }
        }
      }
      // Tail loop: one warp step at a time
      for (; k < tile_size; k += warp_stride) {
        if (k + VEC_WIDTH <= tile_size) {
          // Full vectorized step
          bf16_x8 a_vec = load_128b<USE_NTL>(&row_ptr[kk + k]);
          bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum += static_cast<float>(a_vec.vals[i]) *
                   static_cast<float>(x_vec.vals[i]);
          }
        } else {
          // Tail: scalar accumulation
          for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
            float val_a = static_cast<float>(row_ptr[kk + k + i]);
            float val_x = static_cast<float>(x_tile[k + i]);
            sum += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }
  if (m >= M)
    return;
  // In-warp tree reduction
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_down(sum, offset);
  }
  // Lane 0 writes the result back
  if (lane_id == 0) {
    y[m] = hip_bfloat16(alpha * sum);
  }
  return;
}
/** y = alpha * A^T * x + 0 * y
 * Vectorized + multiple rows per warp + shared-memory cache for x:
 * - JKI order.
 * - Each thread loads VEC_WIDTH bf16 elements at a time (matrix A may use
 *   non-temporal loads when USE_NTL is set).
 * - Each warp produces ROWS_PER_WARP output rows; the in-warp reduction is
 *   done independently per row.
 * - Each lane keeps ROWS_PER_WARP accumulators.
 * - x is staged tile-by-tile in shared memory (LDS).
 */
template <bool USE_NTL = false, int TILE_K = 4096, int ROWS_PER_WARP = 2>
__global__ void gemv_bf16_TN_vec_warp_mr_shm(int M, int K, const float alpha,
                                             const hip_bfloat16 *__restrict__ A,
                                             int lda,
                                             const hip_bfloat16 *__restrict__ x,
                                             hip_bfloat16 *__restrict__ y) {
  int warp_id = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  // Each warp handles ROWS_PER_WARP consecutive rows
  int m_base = blockIdx.x * (blockDim.x / WARP_SIZE) * ROWS_PER_WARP +
               warp_id * ROWS_PER_WARP;
  // Shared-memory staging buffer for one tile of x
  __shared__ hip_bfloat16 x_tile[TILE_K];
  // Each lane keeps ROWS_PER_WARP accumulators
  float sum[ROWS_PER_WARP];
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
    sum[r] = 0.0f;
  }
  // Precompute one row pointer per handled row
  const hip_bfloat16 *row_ptr[ROWS_PER_WARP];
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
    int m = m_base + r;
    // Out-of-range rows alias A so the address stays valid and the later
    // loads need no branch; their results are discarded at write-back.
    row_ptr[r] = (m < M) ? (A + m * lda) : A;
  }
  // Outer loop over all tiles of the K dimension
  for (int kk = 0; kk < K; kk += TILE_K) {
    int tile_size = min(TILE_K, K - kk);
    // Step 1: all threads cooperatively load the current tile of x into LDS
    for (int i = threadIdx.x * VEC_WIDTH; i < tile_size;
         i += blockDim.x * VEC_WIDTH) {
      if (i + VEC_WIDTH <= tile_size) {
        // Full vectorized load
        *reinterpret_cast<bf16_x8 *>(&x_tile[i]) =
            *reinterpret_cast<const bf16_x8 *>(&x[kk + i]);
      } else {
        // Tail: load element by element
        for (int j = 0; j < VEC_WIDTH && i + j < tile_size; ++j) {
          x_tile[i + j] = x[kk + i + j];
        }
      }
    }
    __syncthreads();
    // Step 2: accumulate this tile's contribution
    // Each lane processes ROWS_PER_WARP rows
    const int stride = WARP_SIZE * VEC_WIDTH;
    for (int k = lane_id * VEC_WIDTH; k < tile_size; k += stride) {
      if (k + VEC_WIDTH <= tile_size) {
        // Full vectorized step
        bf16_x8 x_vec = *reinterpret_cast<const bf16_x8 *>(&x_tile[k]);
        bf16_x8 a_vecs[ROWS_PER_WARP];
        // Batched loads, branch-free
#pragma unroll
        for (int r = 0; r < ROWS_PER_WARP; ++r) {
          a_vecs[r] = load_128b<USE_NTL>(&row_ptr[r][kk + k]);
        }
        // Batched accumulation
#pragma unroll
        for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
          for (int i = 0; i < VEC_WIDTH; ++i) {
            sum[r] += static_cast<float>(a_vecs[r].vals[i]) *
                      static_cast<float>(x_vec.vals[i]);
          }
        }
      } else {
        // Tail: scalar accumulation
        for (int i = 0; i < VEC_WIDTH && k + i < tile_size; ++i) {
          float val_x = static_cast<float>(x_tile[k + i]);
#pragma unroll
          for (int r = 0; r < ROWS_PER_WARP; ++r) {
            float val_a = static_cast<float>(row_ptr[r][kk + k + i]);
            sum[r] += val_a * val_x;
          }
        }
      }
    }
    __syncthreads();
  }
  // In-warp reduction (each row reduced independently)
#pragma unroll
  for (int r = 0; r < ROWS_PER_WARP; ++r) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      sum[r] += __shfl_down(sum[r], offset);
    }
    // Lane 0 writes the result back
    if (lane_id == 0) {
      int m = m_base + r;
      if (m < M) {
        y[m] = hip_bfloat16(alpha * sum[r]);
      }
    }
  }
  return;
}
/// GEMV Microbenchmarks
/// y = alpha * A^T * x + beta * y
/// M = 输出维度 (11264)
/// K = 归约维度 (4096)
/// N = 1
int main(int argc, char **argv) {
bool do_verify = false;
float alpha = 1.0f;
int M = 11264;
int K = 4096;
// int N = 1; // Unused
int lda = K;
int block_size = 256;
if (char *value = getCmdOption(argv, argv + argc, "--verify")) {
do_verify = std::stoi(value) == 1;
}
if (char *value = getCmdOption(argv, argv + argc, "--alpha")) {
alpha = std::stof(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-M")) {
M = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-K")) {
K = std::stoi(value);
lda = K;
}
if (char *value = getCmdOption(argv, argv + argc, "--lda")) {
lda = std::stoi(value);
}
if (char *value = getCmdOption(argv, argv + argc, "-B")) {
block_size = std::stoi(value);
}
// transA=T,因此是行优先
size_t count_A = (size_t)M * lda;
size_t size_A = count_A * sizeof(hip_bfloat16);
size_t size_x = (size_t)K * sizeof(hip_bfloat16);
size_t size_y = (size_t)M * sizeof(hip_bfloat16);
// Host 内存分配
std::vector<hip_bfloat16> h_A(count_A);
std::vector<hip_bfloat16> h_x(K);
std::vector<hip_bfloat16> h_y(M);
// 随机初始数据
const float rand_max = static_cast<float>(RAND_MAX);
for (int i = 0; i < count_A; i++)
h_A[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < K; i++)
h_x[i] = hip_bfloat16(static_cast<float>(rand()) / rand_max);
for (int i = 0; i < M; i++)
h_y[i] = hip_bfloat16(0.0f);
// Device 内存分配
hip_bfloat16 *d_A, *d_x, *d_y;
checkHipErrors(hipMalloc(&d_A, size_A));
checkHipErrors(hipMalloc(&d_x, size_x));
checkHipErrors(hipMalloc(&d_y, size_y));
checkHipErrors(hipMemcpy(d_A, h_A.data(), size_A, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_x, h_x.data(), size_x, hipMemcpyHostToDevice));
checkHipErrors(hipMemcpy(d_y, h_y.data(), size_y, hipMemcpyHostToDevice));
// Kernel 注册表
std::vector<KernelCase> kernels;
constexpr bool NTL = true;
constexpr int UNROLL = 4;
constexpr int TILE_K = 4096;
constexpr int ROWS_PER_WARP = 2;
kernels.push_back(
{"naive", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_naive<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"vec8", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"vec8_ntl", [&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int grid = (M + block_size - 1) / block_size;
gemv_bf16_TN_vec<NTL><<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"warp", [&](int M, int K, float alpha, const hip_bfloat16 *A, int lda,
const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"vec8+warp", [&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back(
{"vec8_ntl+warp", [&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp<NTL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<!NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr<NTL, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<!NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp+unroll" + std::to_string(UNROLL),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll<NTL, UNROLL>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<!NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_shm<NTL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<!NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp+unroll" + std::to_string(UNROLL) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = (M + warps_per_block - 1) / warps_per_block;
gemv_bf16_TN_vec_warp_unroll_shm<NTL, UNROLL, TILE_K>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8+warp_mr" + std::to_string(ROWS_PER_WARP) + "+shm" +
std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<!NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
kernels.push_back({"vec8_ntl+warp_mr" + std::to_string(ROWS_PER_WARP) +
"+shm" + std::to_string(TILE_K),
[&](int M, int K, float alpha, const hip_bfloat16 *A,
int lda, const hip_bfloat16 *x, hip_bfloat16 *y) {
int warps_per_block = block_size / WARP_SIZE;
int grid = ((M + ROWS_PER_WARP - 1) / ROWS_PER_WARP +
warps_per_block - 1) /
warps_per_block;
gemv_bf16_TN_vec_warp_mr_shm<NTL, TILE_K, ROWS_PER_WARP>
<<<grid, block_size>>>(M, K, alpha, A, lda, x, y);
}});
// 运行所有测试
run_benchmark(kernels, M, K, alpha, d_A, lda, d_x, d_y, do_verify);
// 清理
checkHipErrors(hipFree(d_A));
checkHipErrors(hipFree(d_x));
checkHipErrors(hipFree(d_y));
return 0;
}
\ No newline at end of file
evo2/bm/gemv_utils.h
View file @
b3a56179
...
...
@@ -37,7 +37,8 @@ inline char *getCmdOption(char **begin, char **end, const std::string &option) {
// --------------------------------------------------------------------------------
inline
void
gemv_cpu
(
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
h_A
,
int
lda
,
const
hip_bfloat16
*
h_x
,
hip_bfloat16
*
h_y
)
{
int
lda
,
const
hip_bfloat16
*
h_x
,
float
beta
,
hip_bfloat16
*
h_y
)
{
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
float
sum
=
0.0
f
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
...
...
@@ -45,7 +46,7 @@ inline void gemv_cpu(int M, int K, float alpha, const hip_bfloat16 *h_A,
float
val_x
=
static_cast
<
float
>
(
h_x
[
k
]);
sum
+=
val_a
*
val_x
;
}
h_y
[
m
]
=
hip_bfloat16
(
alpha
*
sum
);
h_y
[
m
]
=
hip_bfloat16
(
alpha
*
sum
+
beta
*
h_y
[
m
]
);
}
return
;
...
...
@@ -85,9 +86,9 @@ inline bool verify_result(int M, const hip_bfloat16 *h_y_gpu,
// --------------------------------------------------------------------------------
// 定义统一的 Kernel Launcher 签名
using
KernelLauncher
=
std
::
function
<
void
(
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
hip_bfloat16
*
y
)
>
;
using
KernelLauncher
=
std
::
function
<
void
(
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
)
>
;
struct
KernelCase
{
std
::
string
name
;
...
...
@@ -96,7 +97,7 @@ struct KernelCase {
inline
void
run_benchmark
(
const
std
::
vector
<
KernelCase
>
&
cases
,
int
M
,
int
K
,
float
alpha
,
const
hip_bfloat16
*
A
,
int
lda
,
const
hip_bfloat16
*
x
,
hip_bfloat16
*
y
,
const
hip_bfloat16
*
x
,
float
beta
,
hip_bfloat16
*
y
,
bool
do_verify
)
{
std
::
cout
<<
"GEMV Benchmarks"
<<
std
::
endl
;
...
...
@@ -120,7 +121,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
hipMemcpyDeviceToHost
));
// 计算 CPU Reference
gemv_cpu
(
M
,
K
,
alpha
,
h_A
.
data
(),
lda
,
h_x
.
data
(),
h_y_ref
.
data
());
gemv_cpu
(
M
,
K
,
alpha
,
h_A
.
data
(),
lda
,
h_x
.
data
(),
beta
,
h_y_ref
.
data
());
}
// 列宽
...
...
@@ -143,7 +144,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
checkHipErrors
(
hipMemset
(
y
,
0
,
M
*
sizeof
(
hip_bfloat16
)));
// 运行一次
k
.
func
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
y
);
k
.
func
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
checkHipErrors
(
hipDeviceSynchronize
());
// 拷回结果
...
...
@@ -159,7 +160,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
// 2. Warmup
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
k
.
func
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
y
);
k
.
func
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}
checkHipErrors
(
hipDeviceSynchronize
());
...
...
@@ -167,7 +168,7 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
int
num_runs
=
1000
;
checkHipErrors
(
hipEventRecord
(
start
));
for
(
int
i
=
0
;
i
<
num_runs
;
++
i
)
{
k
.
func
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
y
);
k
.
func
(
M
,
K
,
alpha
,
A
,
lda
,
x
,
beta
,
y
);
}
checkHipErrors
(
hipEventRecord
(
stop
));
checkHipErrors
(
hipEventSynchronize
(
stop
));
...
...
@@ -184,8 +185,8 @@ inline void run_benchmark(const std::vector<KernelCase> &cases, int M, int K,
double
bytes_moved
=
(
double
)(
M
*
K
+
K
+
M
)
*
sizeof
(
hip_bfloat16
);
double
bw
=
bytes_moved
/
(
avg_ms
*
1e-3
)
/
1e9
;
printf
(
"%-38s %10.1f %10.2f %10.2f %8s
\n
"
,
k
.
name
.
c_str
(),
avg_ms
*
1e3
,
gflops
,
bw
,
result_status
.
c_str
());
printf
(
"%-38s %10.1f %10.2f %10.2f %8s
\n
"
,
k
.
name
.
c_str
(),
avg_ms
*
1e3
,
gflops
,
bw
,
result_status
.
c_str
());
}
std
::
cout
<<
std
::
string
(
w_table
,
'-'
)
<<
std
::
endl
;
...
...
evo2/bm/run-all.sh
View file @
b3a56179
#!/bin/bash
# BW150
export
HIP_VISIBLE_DEVICES
=
1
BIND_CMD
=
"numactl -N 0 -m 0"
make
# BW150
export
HIP_VISIBLE_DEVICES
=
4
hipprof numactl
-N
0
-m
0 ./gemv_bench
--verify
1
-M
11264
-K
4096
hipprof numactl
-N
0
-m
0 ./gemv_bench
--verify
1
-M
4096
-K
11264
hipprof numactl
-N
0
-m
0 ./gemv_bench
--verify
1
-M
12288
-K
4096
hipprof numactl
-N
0
-m
0 ./gemv_bench
--verify
1
-M
4096
-K
4096
\ No newline at end of file
if
[[
"
$*
"
==
*
"--trace"
*
]]
;
then
PROF_CMD
=
"hipprof --trace-off --pmc"
${
PROF_CMD
}
-o
log/pmc-k1
${
BIND_CMD
}
./gemv_bench
--verify
1
-M
11264
-K
4096
${
PROF_CMD
}
-o
log/pmc-k2
${
BIND_CMD
}
./gemv_bench
--verify
1
-M
4096
-K
11264
${
PROF_CMD
}
-o
log/pmc-k3
${
BIND_CMD
}
./gemv_bench
--verify
1
-M
12288
-K
4096
${
PROF_CMD
}
-o
log/pmc-k4
${
BIND_CMD
}
./gemv_bench
--verify
1
-M
4096
-K
4096
else
${
BIND_CMD
}
./gemv_bench
--verify
1
-M
11264
-K
4096
${
BIND_CMD
}
./gemv_bench
--verify
1
-M
4096
-K
11264
${
BIND_CMD
}
./gemv_bench
--verify
1
-M
12288
-K
4096
${
BIND_CMD
}
./gemv_bench
--verify
1
-M
4096
-K
4096
fi
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment