gaoqiong / composable_kernel_ROCM · Commits

Commit b74918bc
authored Jan 06, 2025 by ThomasNing

compiled version of cross gpu connection

Parents: 3fcad951, 1c45ca35
Changes: 486
Showing 20 changed files with 620 additions and 442 deletions
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp  (+3 -3)
example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp  (+5 -4)
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp  (+0 -3)
example/CMakeLists.txt  (+7 -0)
example/README.md  (+2 -0)
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py  (+12 -4)
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py  (+7 -7)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py  (+31 -22)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py  (+6 -3)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py  (+87 -67)
example/ck_tile/01_fmha/fmha_bwd.cpp  (+7 -7)
example/ck_tile/01_fmha/fmha_bwd.hpp  (+114 -106)
example/ck_tile/01_fmha/fmha_fwd.cpp  (+144 -72)
example/ck_tile/01_fmha/fmha_fwd.hpp  (+126 -86)
example/ck_tile/01_fmha/utils.hpp  (+2 -2)
example/ck_tile/03_gemm/CMakeLists.txt  (+1 -1)
example/ck_tile/03_gemm/README.md  (+3 -0)
example/ck_tile/03_gemm/gemm_basic.cpp  (+19 -21)
example/ck_tile/03_gemm/gemm_basic.hpp  (+1 -15)
example/ck_tile/03_gemm/run_gemm_example.inc  (+43 -19)
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
...
...
@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
            b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
            break;
        default:
-           a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-           a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-           b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+           a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
+           a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A1DataType, 0>{});
+           b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
        }

        d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
...
...
example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp
...
...
@@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification,
    {
    case 0: break;
    case 1:
-       in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
-       wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+       // values generated: -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5
+       in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 6});
+       wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
        break;
    default:
-       in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
-       wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
+       in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
+       wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
    }

    DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
...
...
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
...
...
@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
    a1_device_buf.ToDevice(a1_m_k.mData.data());
    b0_device_buf.ToDevice(b0_k_n.mData.data());
    b1_device_buf.ToDevice(b1_k_n.mData.data());
    e_device_buf.ToDevice(e_m_n_device_result.mData.data());

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
...
...
@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    e_device_buf.FromDevice(e_m_n_device_result.mData.data());

    if(do_verification)
    {
        Tensor<AccDataType> c_m_n({M, N});
...
...
example/CMakeLists.txt
...
...
@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
            list(REMOVE_ITEM FILE_NAME "${source}")
        endif()
    endforeach()
+   #Do not build any DPP examples if DL_KERNELS not set
+   foreach(source IN LISTS FILE_NAME)
+       if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp")
+           message("removing dpp example ${source}")
+           list(REMOVE_ITEM FILE_NAME "${source}")
+       endif()
+   endforeach()
    #Do not build any XDL examples if gfx9 targets are not on the list
    foreach(source IN LISTS FILE_NAME)
        if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
...
...
example/README.md
0 → 100644
+[Back to the main page](../README.md)
+# Composable Kernel examples
\ No newline at end of file
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
...
...
@@ -2,10 +2,17 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

-DTYPE_MAP = {
-    "fp16" : "ck_tile::fp16_t",
-    "bf16" : "ck_tile::bf16_t",
-    "fp8"  : "ck_tile::fp8_t"
+FWD_DTYPE_MAP = {
+    "fp16"    : "FmhaFwdFp16",
+    "bf16"    : "FmhaFwdBf16",
+    "fp8"     : "FmhaFwdFp8",
+    "fp8fp16" : "FmhaFwdFp8Fp16",
+    "fp8bf16" : "FmhaFwdFp8Bf16"
+}
+
+BWD_DTYPE_MAP = {
+    "fp16" : "FmhaBwdFp16",
+    "bf16" : "FmhaBwdBf16"
}

MASK_IMPL = {
...
...
@@ -112,6 +119,7 @@ PIPELINE_MAP = {
PIPELINE_ENUM_MAP = {
    "qr"                : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_async"          : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
+   "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
}

BOOL_MAP = {
...
...
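The values in the renamed maps are no longer raw ck_tile element types but the names of C++ tag structs that the runtime headers (fmha_fwd.hpp / fmha_bwd.hpp in this commit) specialize into full type configurations. The following is a minimal, self-contained sketch of that pattern, not the real ck_tile definitions; the std::uint16_t members are placeholders standing in for ck_tile::fp16_t / ck_tile::bf16_t, and describe() is a hypothetical helper used only for illustration.

    // Sketch only: tag structs select an element-type bundle at compile time.
    #include <cstdint>
    #include <cstdio>

    struct FmhaFwdFp16 {};                              // tag named by FWD_DTYPE_MAP["fp16"]
    struct FmhaFwdBf16 {};                              // tag named by FWD_DTYPE_MAP["bf16"]

    template <typename Tag> struct FmhaFwdTypeConfig;   // primary template left undefined

    template <> struct FmhaFwdTypeConfig<FmhaFwdFp16> {
        using QDataType = std::uint16_t;                // stand-in for ck_tile::fp16_t
    };
    template <> struct FmhaFwdTypeConfig<FmhaFwdBf16> {
        using QDataType = std::uint16_t;                // stand-in for ck_tile::bf16_t
    };

    template <typename Tag>
    void describe(const char* name)
    {
        std::printf("%s: sizeof(QDataType) = %zu\n", name,
                    sizeof(typename FmhaFwdTypeConfig<Tag>::QDataType));
    }

    int main()
    {
        describe<FmhaFwdFp16>("fp16");
        describe<FmhaFwdBf16>("bf16");
    }

Splitting the map into FWD_DTYPE_MAP and BWD_DTYPE_MAP lets the forward path advertise mixed fp8 variants (fp8fp16, fp8bf16) without the backward generator having to know about them.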
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
...
...
@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
            inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode],
                F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias],
                F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
-               F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
+               F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
                F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad],
                F_dvpad=BOOL_MAP[trait.dvpad], F_deterministic=BOOL_MAP[trait.deterministic])
...
...
@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
            FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(F_idx=self.F_idx, F_hdim=self.F_hdim,
-               F_dtype=DTYPE_MAP[self.F_dtype],
+               F_dtype=BWD_DTYPE_MAP[self.F_dtype],
                F_bm0=self.F_tile.F_bm0, F_bn0=self.F_tile.F_bn0, F_bk0=self.F_tile.F_bk0,
...
...
@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
    gen = list()
    api_pool = FmhaBwdApiPool(mask_impl)

-   for dtype in DTYPE_MAP.keys():
+   for dtype in BWD_DTYPE_MAP.keys():
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d == None:
            continue
...
...
@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
            FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(F_idx=self.F_idx, F_hdim=self.F_hdim,
-               F_dtype=DTYPE_MAP[self.F_dtype],
+               F_dtype=BWD_DTYPE_MAP[self.F_dtype],
                F_spad=BOOL_MAP[self.F_spad], F_dvpad=BOOL_MAP[self.F_dvpad], F_mode=MODE_MAP[self.F_mode],
...
...
@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
    gen = list()

-   for dtype in DTYPE_MAP.keys():
+   for dtype in BWD_DTYPE_MAP.keys():
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d == None:
            continue
...
...
@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
            FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(F_idx=self.F_idx, F_hdim=self.F_hdim,
-               F_dtype=DTYPE_MAP[self.F_dtype],
+               F_dtype=BWD_DTYPE_MAP[self.F_dtype],
                F_bm0=self.F_bm0, F_bn0=self.F_bn0, F_spad=BOOL_MAP[self.F_spad],
...
...
@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
    gen = list()

-   for dtype in DTYPE_MAP.keys():
+   for dtype in BWD_DTYPE_MAP.keys():
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d == None:
            continue
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
...
...
@@ -44,13 +44,12 @@ FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
-using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
                                                  ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
-                                                 fmha_warp_tile_{F_idx},
+                                                 ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
                                                  ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
-                                                 fmha_warp_tile_{F_idx},
+                                                 ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
                                                  {F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
...
...
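The hunk above replaces the single shared fmha_warp_tile with two independent warp-tile sequences, one for gemm0 (Q*K^T) and one for gemm1 (P*V). The sketch below models only the shape of that parameterization; TileFmhaShapeSketch and seq are hypothetical stand-ins for ck_tile::TileFmhaShape and ck_tile::sequence, and the numbers are the hdim=128 fp16 entry from the tile table later in this file.

    // Sketch of the new shape parameterization (not the real ck_tile templates).
    #include <cstdio>

    template <int... Is> struct seq {};

    template <class BlockTile,
              class Gemm0Warps, class Gemm0WarpTile,
              class Gemm1Warps, class Gemm1WarpTile,
              bool kIsVLayoutRowMajor>
    struct TileFmhaShapeSketch {};

    // hdim=128 fp16 tile: block tile 128x128x32x128x32x128,
    // 4x1x1 warps with a 32x32x16 warp tile for both gemms.
    using Shape128 = TileFmhaShapeSketch<seq<128, 128, 32, 128, 32, 128>,
                                         seq<4, 1, 1>, seq<32, 32, 16>,
                                         seq<4, 1, 1>, seq<32, 32, 16>,
                                         true>;

    int main()
    {
        Shape128 shape;
        (void)shape;
        std::printf("shape sketch instantiated\n");
    }

Carrying two warp tiles lets a generated kernel pick different MFMA shapes for the two gemms, which is what the extra F_wm1/F_wn1/F_wk1 fields in FmhaFwdTileSize below encode.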
@@ -282,7 +281,7 @@ class FmhaFwdApiPool:
                F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck,
                F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
-               F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
+               F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
            if_j = 'if' if j == 0 else 'else if'
            per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
        if_i = 'if' if i == 0 else 'else if'
...
...
@@ -301,20 +300,24 @@ class FmhaFwdTileSize:
    F_bk1       : int  # tile size along kv gemm unroll
    F_bk0max    : int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0       : int  # number of warps for gemm0 along q seqlen
    F_rn0       : int  # number of warps for gemm0 along k seqlen
    F_rk0       : int  # number of warps for gemm0 along head dim q (not used)
    F_rm1       : int  # number of warps for gemm1 along q seqlen
    F_rn1       : int  # number of warps for gemm1 along head dim v
    F_rk1       : int  # number of warps for gemm1 along k seqlen (not used)
-   F_wm        : int  # warp size along m (warp size)
-   F_wn        : int  # warp size along n
-   F_wk        : int  # warp size along k
+   F_wm0       : int  # gemm0 warp size along m
+   F_wn0       : int  # gemm0 warp size along n
+   F_wk0       : int  # gemm0 warp size along k
+   F_wm1       : int  # gemm1 warp size along m
+   F_wn1       : int  # gemm1 warp size along n
+   F_wk1       : int  # gemm1 warp size along k
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy

    @property
    def name(self) -> str:
        return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
            f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
-           f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
+           f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
+           ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

@dataclass
class FmhaFwdKernel:
...
...
@@ -339,7 +342,7 @@ class FmhaFwdKernel:
            FMHA_FWD_KERNEL_BODY.format(F_idx=self.F_idx, F_hdim=self.F_hdim,
-               F_dtype=DTYPE_MAP[self.F_dtype],
+               F_dtype=FWD_DTYPE_MAP[self.F_dtype],
                F_bm0=self.F_tile.F_bm0, F_bn0=self.F_tile.F_bn0, F_bk0=self.F_tile.F_bk0,
...
...
@@ -352,9 +355,12 @@ class FmhaFwdKernel:
                F_rm1=self.F_tile.F_rm1, F_rn1=self.F_tile.F_rn1, F_rk1=self.F_tile.F_rk1,
-               F_wm=self.F_tile.F_wm, F_wn=self.F_tile.F_wn, F_wk=self.F_tile.F_wk,
+               F_wm0=self.F_tile.F_wm0, F_wn0=self.F_tile.F_wn0, F_wk0=self.F_tile.F_wk0,
+               F_wm1=self.F_tile.F_wm1, F_wn1=self.F_tile.F_wn1, F_wk1=self.F_tile.F_wk1,
                F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad=BOOL_MAP[self.F_pipeline.F_spad], F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
...
...
@@ -409,17 +415,17 @@ class FmhaFwdKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
-           '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
-           '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
-           ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
-           '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
-           '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
+           '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
+           '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
+           ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
+           '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
+           '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
-           '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
-           '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
-           '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
+           '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
+           '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
+           '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
        }
    else:
        return None
...
...
@@ -462,6 +468,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
        # no need lse/dropout kernels
        for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
            pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
+   elif dtype in ['fp8fp16', 'fp8bf16']:
+       # TODO
+       None
    else:
        assert False
    return pipelines
...
...
@@ -469,7 +478,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

-   for dtype in DTYPE_MAP.keys():
+   for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_tile_dict_from_dtype(dtype)
        if d == None:
            continue
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
...
...
@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
            inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
                F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                F_rope_check=ROPE_CHECK_MAP[trait.rope], F_pagedkv=BOOL_MAP[trait.pagedkv],
                F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
-               F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
+               F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
            if_j = 'if' if j == 0 else 'else if'
            per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
        if_i = 'if' if i == 0 else 'else if'
...
...
@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
            FMHA_FWD_APPENDKV_KERNEL_BODY.format(F_idx=self.F_idx, F_hdim=self.F_hdim,
-               F_dtype=DTYPE_MAP[self.F_dtype],
+               F_dtype=FWD_DTYPE_MAP[self.F_dtype],
                F_bs=self.F_tile.F_bs, F_bsk=self.F_tile.F_bsk, F_bd=self.F_tile.F_bd,
...
...
@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
    elif dtype in ['fp8', 'bf8']:
        # rope/paged-kv is not supported
        pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
+   elif dtype in ['fp8fp16', 'fp8bf16']:
+       # TODO
+       None
    else:
        assert False
    return pipelines
...
...
@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
    gen = list()
    api_pool = FmhaFwdAppendKVApiPool(mask_impl)

-   for dtype in DTYPE_MAP.keys():
+   for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
        if d == None:
            continue
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
...
...
@@ -39,6 +39,7 @@ K0_MAX_SUBMAX_MAP = {
FMHA_FWD_SPLITKV_PIPELINE_MAP = {
    "qr"                : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
+   "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
    "qr_async"          : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
}
...
...
@@ -50,13 +51,12 @@ namespace {{
template <bool kHasUnevenSplits>
struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
-using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
                                          ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
-                                         fmha_warp_tile,
+                                         ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
                                          ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
-                                         fmha_warp_tile,
+                                         ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
                                          {F_vlayout}>;
using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
...
...
@@ -112,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}}

using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
    {F_dvpad}>;

#include <iostream>
...
...
@@ -161,9 +161,8 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    {F_hdim},
-   {F_bm0},
-   {F_bn1},
    {F_mode},
+   {F_bn1},
    fmha_trait>;

using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
...
...
@@ -177,9 +176,11 @@ using fmha_epilogue =
        false, false>>;

using fmha_kernel =
-   ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>,
-                                        fmha_pipeline,
-                                        fmha_epilogue>;
+   ck_tile::FmhaFwdSplitKVCombineKernel<
+       ck_tile::FmhaFwdSplitKVCombineTilePartitioner<
+           fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>,
+       fmha_pipeline,
+       fmha_epilogue>;

static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
...
...
@@ -192,7 +193,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
    }};
}}

-using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1},
+using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
    {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;

#include <iostream>
...
...
@@ -231,11 +232,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
    if(s.log_level_ > 0)
        std::cout << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
                  << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
                  << std::flush;
    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
        [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
    );
}}
...
...
@@ -247,12 +248,31 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
}}
"""

-FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) &&
-                   (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
                    ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-       using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
-       using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
-       return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
+       using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+
+       // get combine kernel tile sizes
+       using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
+       constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
+
+       // make sure we can reuse the padding flags in combine kernels
+       static_assert({F_bm0} % kM0 == 0);
+       static_assert({F_bn1} % 32 == 0);
+
+       if (t.has_lse) {{
+           if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
+               return -1;
+           }} else {{
+               using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
+               return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
+           }}
+       }} else {{
+           using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
+           return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
+       }}
    }}
"""
...
...
@@ -292,7 +312,7 @@ class FmhaFwdSplitKVApiTrait:
        if self.pipeline_tag == 'qr_async':
            if self.spad == 't' : return 'true' # always support
            else :                return 'true'
-       elif self.pipeline_tag in ['qr']:
+       elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                return f'a.seqlen_q % {self.bm0} == 0'
        else:   assert False
...
...
@@ -303,7 +323,7 @@ class FmhaFwdSplitKVApiTrait:
        if self.pipeline_tag == 'qr_async':
            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            else :                 return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
-       elif self.pipeline_tag in ['qr', 'qr_fp8']:
+       elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                 return f'a.seqlen_k % {self.bn0} == 0'
        else:   assert False
...
...
@@ -314,7 +334,7 @@ class FmhaFwdSplitKVApiTrait:
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
            else:                assert False
-       elif self.pipeline_tag in ['qr']:
+       elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else:                return f'a.hdim_q % {bk0submax} == 0'
...
...
@@ -326,7 +346,7 @@ class FmhaFwdSplitKVApiTrait:
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
            else:                 assert False
-       elif self.pipeline_tag in ['qr']:
+       elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else:                 return f'a.hdim_v % {bk0submax} == 0'
...
...
@@ -421,11 +441,11 @@ class FmhaFwdSplitKVApiPool:
            inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
                F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
-               F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
+               F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
            if_j = 'if' if j == 0 else 'else if'
            per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
        if_i = 'if' if i == 0 else 'else if'
...
...
@@ -437,12 +457,11 @@ class FmhaFwdSplitKVApiPool:
@dataclass
class FmhaFwdSplitKVCombineTileSize:
-   F_bm0       : int  # tile size along q seqlen
    F_bn1       : int  # tile size along v head_dim
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy

    @property
    def name(self) -> str:
-       return f"b{self.F_bm0}x{self.F_bn1}" +\
+       return f"b{self.F_bn1}" +\
            ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

@dataclass
...
...
@@ -462,7 +481,7 @@ class FmhaFwdSplitKVKernel:
            FMHA_FWD_SPLITKV_KERNEL_BODY.format(F_idx=self.F_idx, F_hdim=self.F_hdim,
-               F_dtype=DTYPE_MAP[self.F_dtype],
+               F_dtype=FWD_DTYPE_MAP[self.F_dtype],
                F_bm0=self.F_tile.F_bm0, F_bn0=self.F_tile.F_bn0, F_bk0=self.F_tile.F_bk0,
...
...
@@ -475,14 +494,17 @@ class FmhaFwdSplitKVKernel:
                F_rm1=self.F_tile.F_rm1, F_rn1=self.F_tile.F_rn1, F_rk1=self.F_tile.F_rk1,
-               F_wm=self.F_tile.F_wm, F_wn=self.F_tile.F_wn, F_wk=self.F_tile.F_wk,
+               F_wm0=self.F_tile.F_wm0, F_wn0=self.F_tile.F_wn0, F_wk0=self.F_tile.F_wk0,
+               F_wm1=self.F_tile.F_wm1, F_wn1=self.F_tile.F_wn1, F_wk1=self.F_tile.F_wk1,
                F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad=BOOL_MAP[self.F_pipeline.F_spad], F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
                F_bias=BIAS_MAP[self.F_pipeline.F_bias], F_lse=BOOL_MAP[self.F_pipeline.F_lse],
                F_squant=BOOL_MAP[self.F_pipeline.F_squant],
...
...
@@ -542,8 +564,7 @@ class FmhaFwdSplitKVCombineKernel:
            FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(F_idx=self.F_idx, F_hdim=self.F_hdim,
-               F_dtype=DTYPE_MAP[self.F_dtype],
-               F_bm0=self.F_tile.F_bm0,
+               F_dtype=FWD_DTYPE_MAP[self.F_dtype],
                F_bn1=self.F_tile.F_bn1,
                F_spad=BOOL_MAP[self.F_pipeline.F_spad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
...
...
@@ -567,17 +588,17 @@ class FmhaFwdSplitKVCombineKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
-           '32'  : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
-           '64'  : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
-           ## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
-           '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
-           '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
+           '32'  : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+           '64'  : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+           ## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+           '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+           '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
-           '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
-           '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
-           '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
+           '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
+           '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
+           '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
        }
    else:
        return None
...
...
@@ -585,17 +606,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
-           '32'  : FmhaFwdSplitKVCombineTileSize(16, 16, -1),
-           '64'  : FmhaFwdSplitKVCombineTileSize(32, 32, -1),
-           ## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
-           '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
-           '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1),
+           '32'  : FmhaFwdSplitKVCombineTileSize(32, -1),
+           '64'  : FmhaFwdSplitKVCombineTileSize(32, -1),
+           ## '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
+           '128' : FmhaFwdSplitKVCombineTileSize(32, -1),
+           '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
-           '64'  : FmhaFwdSplitKVCombineTileSize(64, 32, -1),
-           '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
-           '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
+           '64'  : FmhaFwdSplitKVCombineTileSize(32, -1),
+           '128' : FmhaFwdSplitKVCombineTileSize(32, -1),
+           '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
        }
    else:
        return None
...
...
@@ -614,27 +635,29 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
-           for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
+           for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
                # TODO: use async pipeline when compiler is more stable
                if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
                # if True:
-                   pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
-                   pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
-                   pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
-                   pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                   pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
+                   pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
+                   pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
+                   pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
                else:
-                   pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
-                   pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
-                   pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
-                   pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                   pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
+                   pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
+                   pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
+                   pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
                if receipt == 1:
-                   pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
-                   pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
+                   pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
+                   pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
        elif dtype in ['fp8', 'bf8']:
            # no need lse/paged-kv kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-               pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
+               pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
+       elif dtype in ['fp8fp16', 'fp8bf16']:
+           # TODO
+           None
        else:
            assert False
        return pipelines
...
...
@@ -642,7 +665,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
    gen = list()
    api_pool = FmhaFwdSplitKVApiPool(mask_impl)

-   for dtype in DTYPE_MAP.keys():
+   for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_tile_dict_from_dtype(dtype)
        if d == None:
            continue
...
...
@@ -655,9 +678,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
            if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                continue
-           if pipeline.F_pagedkv == 't':
-               # we only use batch mode kernels to handle (paged-) kvcache problems
-               continue
            k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype,
...
...
@@ -705,7 +725,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
    gen = list()

-   for dtype in DTYPE_MAP.keys():
+   for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
        if d == None:
            continue
...
...
example/ck_tile/01_fmha/fmha_bwd.cpp
...
...
@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
}

// different threshold for different dtype
-template <typename DataType>
+template <typename DataTypeConfig>
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{
    double rtol = 1e-2;
...
...
@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
}

template <>
-auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
+auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
    double rtol = 1e-2;
    double atol = 1e-2;
...
...
@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
    return ck_tile::make_tuple(rtol, atol);
}

-template <typename DataType>
+template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser)
{
    std::string data_type = arg_parser.get_str("prec");
...
...
@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
    const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);

-   using TypeConfig = FmhaBwdTypeConfig<DataType>;
+   using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;

    using QDataType = typename TypeConfig::QDataType;
    using KDataType = typename TypeConfig::KDataType;
...
...
@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        }
        // clang-format on

-       auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
+       auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
        bool dq_cur_pass = ck_tile::check_err(dq_host_result,
                                              dq_host_ref,
                                              std::string("Error: QGrad Incorrect results!"),
...
...
@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
    const std::string data_type = arg_parser.get_str("prec");
    if(data_type == "fp16")
    {
-       return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+       return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16")
    {
-       return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
+       return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
    }

    return -3;
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
...
...
@@ -14,11 +14,19 @@
#include <utility>
#include <variant>

+struct FmhaBwdFp16
+{
+};
+
+struct FmhaBwdBf16
+{
+};
+
template <typename DataType>
struct FmhaBwdTypeConfig;

template <>
-struct FmhaBwdTypeConfig<ck_tile::half_t>
+struct FmhaBwdTypeConfig<FmhaBwdFp16>
{
    using QDataType = ck_tile::half_t;
    using KDataType = ck_tile::half_t;
...
...
@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
};

template <>
-struct FmhaBwdTypeConfig<ck_tile::bf16_t>
+struct FmhaBwdTypeConfig<FmhaBwdBf16>
{
    using QDataType = ck_tile::bf16_t;
    using KDataType = ck_tile::bf16_t;
...
...
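The two hunks above introduce the empty tag structs FmhaBwdFp16 / FmhaBwdBf16 and re-key FmhaBwdTypeConfig on them, which is what lets fmha_bwd.cpp call run<FmhaBwdFp16>() instead of run<ck_tile::half_t>(). Below is a minimal, self-contained sketch of that dispatch pattern; FmhaBwdTypeConfigSketch and run_sketch are hypothetical names used for illustration only, not the real library symbols.

    // Sketch of the tag-type dispatch used by the bwd example (simplified).
    #include <cstdio>
    #include <string>

    struct FmhaBwdFp16 {};
    struct FmhaBwdBf16 {};

    template <typename Tag> struct FmhaBwdTypeConfigSketch;          // stand-in for FmhaBwdTypeConfig
    template <> struct FmhaBwdTypeConfigSketch<FmhaBwdFp16> { static constexpr const char* name = "fp16"; };
    template <> struct FmhaBwdTypeConfigSketch<FmhaBwdBf16> { static constexpr const char* name = "bf16"; };

    template <typename DataTypeConfig>
    bool run_sketch()
    {
        using TypeConfig = FmhaBwdTypeConfigSketch<DataTypeConfig>;
        std::printf("running bwd example with %s type config\n", TypeConfig::name);
        return true;
    }

    int main(int argc, char* argv[])
    {
        // stands in for arg_parser.get_str("prec")
        const std::string data_type = (argc > 1) ? argv[1] : "fp16";
        if(data_type == "fp16")      { return run_sketch<FmhaBwdFp16>() ? 0 : -2; }
        else if(data_type == "bf16") { return run_sketch<FmhaBwdBf16>() ? 0 : -2; }
        return -3;
    }

Using tags instead of element types means one precision string can map to a whole bundle of Q/K/V/gradient types, and future mixed-precision configs only need a new tag plus a new specialization.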
@@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
        // create group mode kernel arguments
        if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
        {
-           return FmhaBwdDQDKDVKernel::MakeKargs(
+           return FmhaBwdDQDKDVKernel::MakeKargsImpl(
                args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr, args.lse_ptr, args.do_ptr,
                args.d_ptr, args.rand_val_ptr, args.dk_ptr, args.dv_ptr, args.dbias_ptr, args.dq_acc_ptr,
                args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr,
                args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, args.scale,
                args.stride_q, args.stride_k, args.stride_v, args.stride_bias, args.stride_randval,
                args.stride_do, args.stride_dq_acc, args.stride_dk, args.stride_dv, args.stride_dbias,
                args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias,
                args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed,
                args.nhead_stride_dq_acc, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias,
                args.split_stride_dq_acc, args.window_size_left, args.window_size_right, args.mask_type,
                args.p_drop, args.drop_seed_offset);
        }
        else
        {
            // create batch mode kernel arguments
-           return FmhaBwdDQDKDVKernel::MakeKargs(
+           return FmhaBwdDQDKDVKernel::MakeKargsImpl(
                args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr, args.lse_ptr, args.do_ptr,
                args.d_ptr, args.rand_val_ptr, args.dk_ptr, args.dv_ptr, args.dbias_ptr, args.dq_acc_ptr,
                args.seqlen_q, args.seqlen_k,
                args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, args.scale,
                args.stride_q, args.stride_k, args.stride_v, args.stride_bias, args.stride_randval,
                args.stride_do, args.stride_dq_acc, args.stride_dk, args.stride_dv, args.stride_dbias,
                args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias,
                args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed,
                args.nhead_stride_dq_acc, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias,
                args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias,
                args.batch_stride_randval, args.batch_stride_do, args.batch_stride_lsed,
                args.batch_stride_dq_acc, args.batch_stride_dk, args.batch_stride_dv, args.batch_stride_dbias,
                args.split_stride_dq_acc, args.window_size_left, args.window_size_right, args.mask_type,
                args.p_drop, args.drop_seed_offset);
        }
    }();
...
example/ck_tile/01_fmha/fmha_fwd.cpp
...
...
@@ -3,6 +3,7 @@
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ref/naive_attention.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include "utils.hpp"
...
...
@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
-   arg_parser.insert("v", "1", "weather do CPU validation or not")
+   arg_parser.insert("v", "1", "0:no validation, 1:cpu validation, 2:gpu validation(experimental)")
        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
        .insert("b", "2", "batch size")
        .insert("h", "8", "num of head, for q")
...
...
@@ -62,7 +63,7 @@ auto create_args(int argc, char* argv[])
"-1 to choose s_knew in [1, s] randomly."
)
.
insert
(
"s_kpad"
,
"-1"
,
"seqlen_k stride between 2
token
s, currently used in group-mode only
\n
"
"seqlen_k stride between 2
batche
s, currently used in group-mode only
\n
"
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
\n
"
"along seqlen, instead of packed. same as xformer kv_padding"
)
.
insert
(
"d"
,
"128"
,
"head dim for q, k"
)
...
...
@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[])
}

// different threshold for different dtype
-template <typename DataType>
+template <typename DataTypeConfig>
auto get_elimit(std::string /*init_method*/)
{
    double rtol = 1e-3;
...
...
@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/)
}

template <>
-auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
+auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1e-2;
...
...
@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
}

template <>
-auto get_elimit<ck_tile::fp8_t>(std::string init_method)
+auto get_elimit<FmhaFwdFp8>(std::string init_method)
{
    if(init_method == "ui" || init_method == "ni")
    {
...
...
@@ -261,7 +262,7 @@ int override_num_splits_if_necessary(
    return num_splits;
}

-template <typename DataType>
+template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser)
{
    std::string data_type = arg_parser.get_str("prec");
...
...
@@ -294,7 +295,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
#if !CK_TILE_FMHA_FWD_APPENDKV_API
    if(seqlen_knew != 0)
    {
-       std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
+       std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
+                 << std::endl;
        seqlen_knew = 0;
    }
#endif
...
...
@@ -304,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
    }

    ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
-   if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> || std::is_same_v<DataType, ck_tile::bf16_t>))
+   if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
+                  std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
    {
        if(0 < rotary_dim)
        {
...
...
@@ -321,6 +323,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
        rotary_dim = 0;
    }
#endif

+   // to use fmha_fwd_appendkv(), make sure it's in batch mode
+   const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+   if(need_append_kvcache && mode == mode_enum::group)
+   {
+       std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
+       mode = mode_enum::batch;
+   }

    if(!(rotary_dim <= hdim_q))
    {
        std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
...
...
@@ -356,22 +365,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
                  << std::endl;
        use_cache_batch_idx = false;
    }
-#endif
-   if(0 < page_block_size && use_cache_batch_idx)
+#else
+   if(use_cache_batch_idx)
    {
-       std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
-                    "'cache_batch_idx' option"
-                 << std::endl;
-       use_cache_batch_idx = false;
+       if(0 < page_block_size)
+       {
+           std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
+                        "'cache_batch_idx' option"
+                     << std::endl;
+           use_cache_batch_idx = false;
+       }
+       else if(mode == mode_enum::group)
+       {
+           std::cerr << "group mode will not use cache_batch_idx. ignoring the "
+                        "'cache_batch_idx' option"
+                     << std::endl;
+           use_cache_batch_idx = false;
+       }
    }
    // the input tensor layout for kvcache is same as batch mode
-   const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+#endif
    const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
    if(use_kvcache && mode != mode_enum::batch)
    {
        std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
        mode = mode_enum::batch;
    }

    auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
...
...
@@ -380,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        arg_parser.get_str("s_k"),
        arg_parser.get_str("s_kpad"),
        /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
-       use_kvcache);
+       need_append_kvcache);

    // compute kvcache seqlen_k (before appending knew/vnew)
    auto cache_seqlen_ks = seqlen_ks;
    std::transform(cache_seqlen_ks.begin(),
...
...
@@ -416,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
        return atoi(squant_str.c_str()) != 0 ? true : false;
    }();

-   float range_q = arg_parser.get_float("range_q");
-   float range_k = arg_parser.get_float("range_k");
-   float range_v = arg_parser.get_float("range_v");
-   float range_p = arg_parser.get_float("range_p");
-   float range_o = arg_parser.get_float("range_o");
-
-   float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
-
-   float scale_p = 1.f;
-   float scale_o = 1.f;
-
-   if(squant)
-   {
-       scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
-       scale_p = dtype_max / range_p;
-       // scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
-       scale_o = range_p * range_v / range_o / dtype_max;
-   }

    std::string vlayout = arg_parser.get_str("vlayout");
    bool lse = arg_parser.get_bool("lse");
...
...
@@ -454,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    }

    bool s_randval = false;
-   if(p_drop > 0.0f && do_validation)
+   if(p_drop > 0.0f && do_validation != 0)
    {
        s_randval = true;
    }
...
...
@@ -487,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    const auto seqstart_k_host              = to_seqstarts(seqlen_ks);
    const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);

-   using TypeConfig = FmhaFwdTypeConfig<DataType>;
+   using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;

    using QDataType = typename TypeConfig::QDataType;
    using KDataType = typename TypeConfig::KDataType;
...
...
@@ -501,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
    using OaccDataType = typename TypeConfig::OaccDataType;
    using ODataType    = typename TypeConfig::ODataType;

+   float range_q = arg_parser.get_float("range_q");
+   float range_k = arg_parser.get_float("range_k");
+   float range_v = arg_parser.get_float("range_v");
+   float range_p = arg_parser.get_float("range_p");
+   float range_o = arg_parser.get_float("range_o");
+
+   float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
+   float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
+   float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
+   float p_dtype_max = v_dtype_max; // assume p and v is the same type
+   float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
+
+   float scale_p = 1.f;
+   float scale_o = 1.f;
+
+   if(squant)
+   {
+       scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
+       scale_p = p_dtype_max / range_p;
+       scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
+   }

    // accumulation numbers for performance evaluation
    std::size_t flop = 0, num_byte = 0;
    auto max_seqlen_q =
...
...
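The relocated quantization block now reads each tensor's own numeric maximum (q/k/v/p/o) instead of a single dtype_max, so mixed-precision configs can use different limits per operand. The standalone sketch below just walks the same arithmetic with assumed numbers: the 448 maximum (typical of fp8 e4m3) and the range_* values are illustrative assumptions, not values taken from the commit.

    // Worked example of the fp8 static-quantization scales, with assumed inputs.
    #include <cstdio>

    int main()
    {
        // Assumed values for illustration only.
        const float q_dtype_max = 448.0f, k_dtype_max = 448.0f, v_dtype_max = 448.0f, o_dtype_max = 448.0f;
        const float p_dtype_max = v_dtype_max;      // same assumption as the example: P and V share a type
        const float range_q = 16.0f, range_k = 16.0f, range_v = 16.0f, range_p = 1.0f, range_o = 16.0f;

        float scale_s = 1.0f / 8.0f;                // e.g. 1/sqrt(hdim) with hdim = 64
        scale_s *= (range_q / q_dtype_max) * (range_k / k_dtype_max);
        const float scale_p = p_dtype_max / range_p;  // re-expand softmax probabilities into the fp8 range
        const float scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);

        std::printf("scale_s = %g, scale_p = %g, scale_o = %g\n", scale_s, scale_p, scale_o);
    }

The block also had to move below the TypeConfig aliases because it now depends on QDataType/KDataType/VDataType/ODataType, which is why the old copy earlier in run() was deleted.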
@@ -697,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
    else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") // suitable for fp8 quantization
    {
-       ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
-       ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
-       ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
-       ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
-       ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
+       ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
+       ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
+       ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
+       ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
+       ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);

        // bias_fp8 = qscale_bias * bias_fp32
-       float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
+       float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
        // Assume bias is in [-1.f, 1.f] in original fp32
        ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
    }
...
...
@@ -741,8 +757,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
    ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
    ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
-   ck_tile::DeviceMem seqlen_k_buf(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
+   ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
+                                       ? seqlen_ks.size() * sizeof(int32_t)
+                                       : 0);
    ck_tile::DeviceMem cache_seqlen_k_buf(need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
    ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
...
...
@@ -763,7 +781,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
    seqstart_q.ToDevice(seqstart_q_host.data());
    seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() : seqstart_k_with_padding_host.data());
-   seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
+   seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
+                             ? seqlen_ks.data()
+                             : nullptr);
    cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
    rotary_cos_buf.ToDevice(rotary_cos_host.data());
    rotary_sin_buf.ToDevice(rotary_sin_host.data());
...
...
@@ -976,8 +996,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
    args.seqstart_q_ptr = (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
    args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
-    args.seqlen_k_ptr = (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
+    args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
+                             ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
    args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
    args.max_seqlen_q = max_seqlen_q;
...
@@ -1029,6 +1050,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
    args.batch_stride_block_table = batch_stride_block_table;
    args.page_block_size = page_block_size;
+    args.is_gappy = false; // use 'false' for flash-attention integration
    args.cache_batch_idx = (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
...
@@ -1100,25 +1122,75 @@ bool run(const ck_tile::ArgParser& arg_parser)
              << std::setprecision(2) << tflops << " TFlops, "
              << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush;

-    if(!do_validation)
+    if(do_validation == 0)
    {
        std::cout << std::flush << std::endl;
        return true;
    }

+    if(do_validation == 2)
+    {
+        // NOTE: use gpu to do validation
+        ck_tile::naive_attention_fwd_traits naive_t;
+        naive_t.q_type    = data_type;
+        naive_t.k_type    = data_type;
+        naive_t.v_type    = data_type;
+        naive_t.o_type    = data_type;
+        naive_t.q_layout  = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.k_layout  = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.v_layout  = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.o_layout  = o_perm == 1 ? "bhsd" : "bshd";
+        naive_t.variation = 0; // TODO?
+
+        ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
+
+        ck_tile::naive_attention_fwd_args naive_a;
+        naive_a.q_ptr           = q_buf.GetDeviceBuffer();
+        naive_a.k_ptr           = k_buf.GetDeviceBuffer();
+        naive_a.v_ptr           = v_buf.GetDeviceBuffer();
+        naive_a.o_ptr           = o_naive_buf.GetDeviceBuffer();
+        naive_a.scale_s         = scale_s;
+        naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
+        naive_a.page_table_ptr  = nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn)
+        naive_a.hdim            = hdim_q;
+        naive_a.hdim_v          = hdim_v; // could be cross-attn, where V and Q/K hdim are different
+        naive_a.batch_q         = batch;
+        naive_a.batch_kv        = batch;
+        naive_a.batch_ratio_kv  = 1; // batch_q / batch_kv
+        naive_a.seqlen_q        = seqlen_qs[0];
+        naive_a.seqlen_kv       = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
+        naive_a.nhead_q         = nhead;
+        naive_a.nhead_kv        = nhead_k;
+        naive_a.nhead_ratio_kv  = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
+        naive_a.page_size       = 0; // if paged, the seqlen-kv for each block
+
+        ck_tile::stream_config naive_s{};
+        naive_attention_fwd(naive_t, naive_a, naive_s);
+
+        auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
+        o_buf.FromDevice(o_host.data()); // TODO: ugly
+        auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
+        bool pass_ = ck_tile::check_err(
+            o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
+        std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
+        return pass_;
+    }

    o_buf.FromDevice(o_host.data());
    lse_buf.FromDevice(lse_host.data());
    randval_buf.FromDevice(randval_host.data());

    auto p_compute_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+        if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
            return ck_tile::scales{scale_p};
        else
            return ck_tile::identity{};
    }();

    auto oacc_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+        if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
            return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{}, ck_tile::scales{scale_o});
        else
...
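The `do_validation == 2` path above checks the tiled kernel against `ck_tile::naive_attention_fwd` on the GPU. For readers who want the math being compared, here is a hedged host-side sketch of the same reference computation for a single batch and a single head; every name in it is local to the sketch, and only the formula O = softmax(scale_s · Q·Kᵀ)·V is assumed (no masking, dropout, GQA, or quantization):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// O = softmax(scale_s * Q * K^T) * V for one batch and one head, row-major fp32 buffers.
std::vector<float> naive_attention_ref(const std::vector<float>& Q, // [seqlen_q , hdim  ]
                                       const std::vector<float>& K, // [seqlen_kv, hdim  ]
                                       const std::vector<float>& V, // [seqlen_kv, hdim_v]
                                       int seqlen_q,
                                       int seqlen_kv,
                                       int hdim,
                                       int hdim_v,
                                       float scale_s)
{
    std::vector<float> O(static_cast<std::size_t>(seqlen_q) * hdim_v, 0.f);
    std::vector<float> s(seqlen_kv);

    for(int i = 0; i < seqlen_q; ++i)
    {
        // s_j = scale_s * dot(Q_i, K_j), tracking the row max for a stable softmax
        float row_max = -std::numeric_limits<float>::infinity();
        for(int j = 0; j < seqlen_kv; ++j)
        {
            float acc = 0.f;
            for(int d = 0; d < hdim; ++d)
                acc += Q[i * hdim + d] * K[j * hdim + d];
            s[j]    = scale_s * acc;
            row_max = std::max(row_max, s[j]);
        }

        // p = softmax(s), with the usual max subtraction
        float sum = 0.f;
        for(int j = 0; j < seqlen_kv; ++j)
        {
            s[j] = std::exp(s[j] - row_max);
            sum += s[j];
        }

        // O_i = sum_j p_j * V_j
        for(int j = 0; j < seqlen_kv; ++j)
        {
            const float p = s[j] / sum;
            for(int d = 0; d < hdim_v; ++d)
                O[static_cast<std::size_t>(i) * hdim_v + d] += p * V[static_cast<std::size_t>(j) * hdim_v + d];
        }
    }
    return O;
}
```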
@@ -1168,7 +1240,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
            {
                decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());

                auto [rotary_cos_slice, rotary_sin_slice] =
                    slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);

                ck_tile::reference_batched_rotary_position_embedding(
...
@@ -1184,13 +1256,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
                    k_host_ref.ForEach([&](auto& self, auto i) {
                        self(i) = k_host(block_table_host(wb, i[1] / page_block_size),
                                         i[0] / nr,
                                         i[1] % page_block_size,
                                         i[2]);
                    });
                }
                else
                {
                    k_host_ref.ForEach([&](auto& self, auto i) {
                        self(i) = k_host(block_table_host(wb, i[1] / page_block_size),
                                         i[1] % page_block_size,
                                         i[0] / nr,
                                         i[2]);
                    });
                }
            }
            else
#endif
            {
                if(i_perm)
                    k_host_ref.ForEach([&](auto& self, auto i) {
                        self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]);
                    });
                else
                    k_host_ref.ForEach([&](auto& self, auto i) {
                        self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]);
                    });
...
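The gathers above rebuild the K cache from paged storage: a logical position along seqlen_k is split into a logical page (`i[1] / page_block_size`, looked up through the per-batch block table) and an offset inside that page (`i[1] % page_block_size`). A minimal sketch of just that index math, with illustrative names (`PagedSlot` and `to_paged_slot` are not part of the example code):

```cpp
#include <cstdint>

struct PagedSlot
{
    int32_t physical_block; // block id taken from the per-batch row of the block table
    int32_t offset;         // position inside that block
};

// `block_table_row` points at the block-table row of one batch, i.e.
// block_table_row[logical_block] == physical block id.
inline PagedSlot to_paged_slot(const int32_t* block_table_row,
                               int32_t logical_pos,
                               int32_t page_block_size)
{
    return PagedSlot{block_table_row[logical_pos / page_block_size],
                     logical_pos % page_block_size};
}
```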
@@ -1211,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
            {
                knew_host_ref_ro.emplace(knew_host_ref.get_lengths());

                auto [rotary_cos_slice, rotary_sin_slice] =
                    slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);

                ck_tile::reference_batched_rotary_position_embedding(
...
@@ -1233,19 +1305,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
            if(0 < page_block_size)
            {
                if(is_v_rowmajor)
                {
                    if(i_perm)
                    {
                        v_host_ref.ForEach([&](auto& self, auto i) {
                            self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                             i[0] / nr,
                                             i[2] % page_block_size,
                                             i[1]);
                        });
                    }
                    else
                    {
                        v_host_ref.ForEach([&](auto& self, auto i) {
                            self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                             i[2] % page_block_size,
                                             i[0] / nr,
                                             i[1]);
                        });
                    }
                }
                else
                {
                    if(i_perm)
                    {
                        v_host_ref.ForEach([&](auto& self, auto i) {
                            self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                             i[0] / nr,
                                             i[1],
                                             i[2] % page_block_size);
                        });
                    }
                    else
                    {
...
@@ -1440,7 +1512,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        else
            o_host_result.ForEach([&](auto& self, auto idx) {
                self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]);
            });
        // clang-format on

-        auto [rtol, atol] = get_elimit<DataType>(init_method);
+        auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
        bool cur_pass = ck_tile::check_err(
            o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
        pass &= cur_pass;
...
@@ -1497,15 +1569,15 @@ int main(int argc, char* argv[])
    const std::string data_type = arg_parser.get_str("prec");
    if(data_type == "fp16")
    {
-        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16")
    {
-        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "fp8")
    {
-        return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
    }

    return -3;
...
example/ck_tile/01_fmha/fmha_fwd.hpp  View file @ b74918bc
...
@@ -16,11 +16,35 @@
 #include <utility>
 #include <variant>

+struct FmhaFwdFp16
+{
+};
+
+struct FmhaFwdBf16
+{
+};
+
+struct FmhaFwdFp8
+{
+};
+
+struct FmhaFwdBf8
+{
+};
+
+struct FmhaFwdFp8Fp16
+{
+};
+
+struct FmhaFwdFp8Bf16
+{
+};
+
 template <typename DataType>
 struct FmhaFwdTypeConfig;

 template <>
-struct FmhaFwdTypeConfig<ck_tile::half_t>
+struct FmhaFwdTypeConfig<FmhaFwdFp16>
 {
     using QDataType = ck_tile::half_t;
     using KDataType = ck_tile::half_t;
...
@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
 };

 template <>
-struct FmhaFwdTypeConfig<ck_tile::bf16_t>
+struct FmhaFwdTypeConfig<FmhaFwdBf16>
 {
     using QDataType = ck_tile::bf16_t;
     using KDataType = ck_tile::bf16_t;
...
@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
 };

 template <>
-struct FmhaFwdTypeConfig<ck_tile::fp8_t>
+struct FmhaFwdTypeConfig<FmhaFwdFp8>
 {
     using QDataType = ck_tile::fp8_t;
     using KDataType = ck_tile::fp8_t;
...
@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
 };

 template <>
-struct FmhaFwdTypeConfig<ck_tile::bf8_t>
+struct FmhaFwdTypeConfig<FmhaFwdBf8>
 {
     using QDataType = ck_tile::bf8_t;
     using KDataType = ck_tile::bf8_t;
...
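The empty structs added above act as compile-time tags: `run<FmhaFwdFp16>(...)` in fmha_fwd.cpp now selects the matching `FmhaFwdTypeConfig` specialization instead of spelling out `ck_tile` element types directly. A small sketch of that lookup (it assumes fmha_fwd.hpp is included; `print_q_element_size` is an illustrative helper, not part of the example):

```cpp
#include <iostream>
// #include "fmha_fwd.hpp"  // provides FmhaFwdFp16, FmhaFwdFp8, FmhaFwdTypeConfig<...>

template <typename DataTypeConfig>
void print_q_element_size()
{
    // the tag type is only used to look up the concrete element types
    using QDataType = typename FmhaFwdTypeConfig<DataTypeConfig>::QDataType;
    std::cout << "sizeof(QDataType) = " << sizeof(QDataType) << " bytes\n";
}

// usage:
//   print_q_element_size<FmhaFwdFp16>(); // ck_tile::half_t -> 2 bytes
//   print_q_element_size<FmhaFwdFp8>();  // ck_tile::fp8_t  -> 1 byte
```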
@@ -165,6 +189,8 @@ struct fmha_fwd_splitkv_args
    void* block_table_ptr;
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
+    bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
+                   // nullptr.

    const void* cache_batch_idx;
...
@@ -173,9 +199,21 @@ struct fmha_fwd_splitkv_args
    //             seqlen_k = kargs.seqlen_k
    // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
-    //             kvcache mode (use same kernel as batch mode):
-    //             or kargs.seqlen_k_ptr[b]
+    //
+    // batch mode (kvcache):
+    //             seqlen_q = kargs.seqlen_q
+    //             seqlen_k = kargs.seqlen_k_ptr[b]
+    // group mode (kvcache):
+    //             seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
+    //
+    //             when is_gappy=true:
+    //                 seqlen_k = kargs.seqlen_k_ptr[b]
+    //                 seqstart_k_ptr[b] now store local offset of each batch
+    //
+    //             when is_gappy=false:
+    //                 seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
+    //                 or kargs.seqlen_k_ptr[b]
    const void* seqstart_q_ptr;
    const void* seqstart_k_ptr;
    const void* seqlen_k_ptr;
...
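A hedged host-side sketch of the per-batch arithmetic described in the comment above; the function and parameter names are illustrative stand-ins for the kernel-argument fields (`seqlen_q`, `seqlen_k`, `seqstart_q_ptr`, `seqstart_k_ptr`, `seqlen_k_ptr`, `is_gappy`):

```cpp
#include <cstdint>

struct SeqLens
{
    int32_t seqlen_q;
    int32_t seqlen_k;
};

inline SeqLens splitkv_seqlens_for_batch(int b,
                                         bool is_group_mode,
                                         int32_t kargs_seqlen_q,
                                         int32_t kargs_seqlen_k,
                                         const int32_t* seqstart_q_ptr, // group mode only
                                         const int32_t* seqstart_k_ptr, // group mode only
                                         const int32_t* seqlen_k_ptr)   // kvcache only, else nullptr
{
    SeqLens r{};
    if(!is_group_mode)
    {
        // batch mode: fixed seqlen_q; with kvcache the per-batch length comes from seqlen_k_ptr
        r.seqlen_q = kargs_seqlen_q;
        r.seqlen_k = (seqlen_k_ptr != nullptr) ? seqlen_k_ptr[b] : kargs_seqlen_k;
    }
    else
    {
        // group mode: seqlen_q always comes from the q start offsets
        r.seqlen_q = seqstart_q_ptr[b + 1] - seqstart_q_ptr[b];
        // with kvcache (and in particular when is_gappy=true, where seqstart_k_ptr[b]
        // holds a local offset rather than a prefix sum), seqlen_k comes from seqlen_k_ptr
        r.seqlen_k = (seqlen_k_ptr != nullptr) ? seqlen_k_ptr[b]
                                               : seqstart_k_ptr[b + 1] - seqstart_k_ptr[b];
    }
    return r;
}
```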
@@ -251,7 +289,7 @@ struct fmha_fwd_appendkv_args
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
-    const void* cache_batch_idx;
+    const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
...
@@ -278,87 +316,87 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
        // create group mode kernel arguments
        if constexpr(FmhaKernel::kIsGroupMode)
        {
-            return FmhaKernel::MakeKargs(args.q_ptr,
+            return FmhaKernel::MakeKargsImpl(args.q_ptr,
                args.k_ptr, args.v_ptr, args.bias_ptr, args.rand_val_ptr,
                args.lse_ptr, args.o_ptr,
                args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr,
                args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k,
                args.scale_s, args.scale_p, args.scale_o,
                args.stride_q, args.stride_k, args.stride_v, args.stride_bias,
                args.stride_randval, args.stride_o,
                args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v,
                args.nhead_stride_bias, args.nhead_stride_randval, args.nhead_stride_lse,
                args.nhead_stride_o,
                args.window_size_left, args.window_size_right, args.mask_type,
                args.p_drop, args.s_randval, args.drop_seed_offset);
        }
        else
        {
            // create batch mode kernel arguments
-            return FmhaKernel::MakeKargs(args.q_ptr,
+            return FmhaKernel::MakeKargsImpl(args.q_ptr,
                args.k_ptr, args.v_ptr, args.bias_ptr, args.rand_val_ptr,
                args.lse_ptr, args.o_ptr,
                args.seqlen_q, args.seqlen_k,
                args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k,
                args.scale_s, args.scale_p, args.scale_o,
                args.stride_q, args.stride_k, args.stride_v, args.stride_bias,
                args.stride_randval, args.stride_o,
                args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v,
                args.nhead_stride_bias, args.nhead_stride_randval, args.nhead_stride_lse,
                args.nhead_stride_o,
                args.batch_stride_q, args.batch_stride_k, args.batch_stride_v,
                args.batch_stride_bias, args.batch_stride_randval, args.batch_stride_lse,
                args.batch_stride_o,
                args.window_size_left, args.window_size_right, args.mask_type,
                args.p_drop, args.s_randval, args.drop_seed_offset);
        }
    }();
...
@@ -389,6 +427,10 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
            args.nhead_q,
            args.nhead_q / args.nhead_k,
            args.num_splits,
+            args.block_table_ptr,
+            args.batch_stride_block_table,
+            args.page_block_size,
+            args.is_gappy,
            args.scale_s,
            args.scale_p,
            args.stride_q,
...
@@ -667,7 +709,6 @@ std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN1_,
          bool kStoreLse_,
          bool kDoFp8StaticQuant_,
...
@@ -678,7 +719,6 @@ struct fmha_fwd_splitkv_combine_traits_
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode = kIsGroupMode_;
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN1 = kN1_;
    static constexpr bool kStoreLse = kStoreLse_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
...
example/ck_tile/01_fmha/utils.hpp  View file @ b74918bc
...
@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
              std::string k_val,
              std::string k_pad_val,
              ck_tile::index_t seqlen_k_min = 0,
-              bool use_kvcache = false,
+              bool need_append_kvcache = false,
              std::optional<unsigned> seed = std::nullopt)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
...
@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
    const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);

-    if(1 < batch && use_kvcache)
+    if(1 < batch && need_append_kvcache)
    {
        // to keep the original s_k value, we always use seqlen_k_max in first batch
        randints(std::next(seqlen_ks.begin()),
...
example/ck_tile/03_gemm/CMakeLists.txt  View file @ b74918bc
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp)
add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
example/ck_tile/03_gemm/README.md  View file @ b74918bc
...
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementation
 mkdir build && cd build
 # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
 sh ../script/cmake-ck-dev.sh ../ <arch>
+# The basic pipeline method on the gemm calculation
 make tile_example_gemm_basic -j
+# The memory bound pipeline on the gemm calculation
+make tile_example_gemm_mem_pipeline -j
 ```
 This will result in an executable `build/bin/tile_example_gemm_basic`
...
example/ck_tile/03_gemm/gemm_basic.cpp  View file @ b74918bc
...
@@ -15,12 +15,13 @@
 #include "gemm_basic.hpp"

 template <typename ALayout, typename BLayout, typename CLayout>
-float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 {
-    // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
-    constexpr bool kPadA = true;
-    constexpr bool kPadB = true;
-    constexpr bool kPadC = true;
+    // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
+    constexpr bool kPadM = false;
+    constexpr bool kPadN = false;
+    constexpr bool kPadK = false;

     constexpr bool kTilePermute = false;
     // The rank and permutation will also be generate out by the CodeGen part.
     constexpr ck_tile::index_t kOutputRank = 2;
...
@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
        CShuffleEpilogue,
        ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
                                                                   CDataType,
-                                                                   kPadA,
-                                                                   kPadB,
+                                                                   kPadM,
+                                                                   kPadN,
                                                                   kTilePermute,
                                                                   kOutputRank,
                                                                   1,
...
@@ -65,32 +66,29 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
                                                                   TilePartitioner::kM,
                                                                   TilePartitioner::kN>>,
        ck_tile::Default2DEpilogue<
-            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
+            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;

    using CodegenGemmTraits =
-        ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
+        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
    using CodegenPipelineProblem = ck_tile::
        GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
-    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy<ALayout, BLayout, CLayout>;
+    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
    using CodegenGemmPipeline =
        ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
    // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
    // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
    using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

-    auto kargs = Kernel::MakeKargs(args.p_a, args.p_b, args.p_c, args.M, args.N, args.K,
-                                   args.stride_A, args.stride_B, args.stride_C);
-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
+    auto kargs = Kernel::MakeKernelArgs(args);
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
    constexpr dim3 blocks = Kernel::BlockSize();

    if(!Kernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
    }

    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args:"
...
example/ck_tile/03_gemm/gemm_basic.hpp  View file @ b74918bc
...
@@ -51,20 +51,6 @@ using BDataType = Types::BDataType;
 using AccDataType = Types::AccDataType;
 using CDataType = Types::CDataType;

-struct gemm_basic_args
-{
-    const void* p_a;
-    const void* p_b;
-    void* p_c;
-    ck_tile::index_t kbatch;
-    ck_tile::index_t M;
-    ck_tile::index_t N;
-    ck_tile::index_t K;
-    ck_tile::index_t stride_A;
-    ck_tile::index_t stride_B;
-    ck_tile::index_t stride_C;
-};

 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
...
@@ -89,4 +75,4 @@ auto create_args(int argc, char* argv[])
 }

 // host API
-float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
example/ck_tile/03_gemm/run_gemm_example.inc  View file @ b74918bc
...
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                  int n_warmup,
                  int n_repeat)
{
-    gemm_basic_args args;
-    args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
-    args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
-    args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
-    args.kbatch = kbatch;
+    ck_tile::GemmHostArgs args;
+    args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
+    args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
+    args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+    args.k_batch = kbatch;
    args.M = M;
    args.N = N;
    args.K = K;
...
@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
    float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

-    std::string op_name{"Gemm{MemBoundPipeline}"};
-
    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_byte =
        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

-    std::cout << "Run " << op_name << " kernel with M =" << M << " N =" << N << " K =" << K
+    std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;
...
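As a worked example of the metrics above (assuming 2-byte A, B and C elements, e.g. fp16): for M = N = K = 4096, flop = 2·4096³ ≈ 1.374·10¹¹ and num_byte = 2·(M·K + N·K + M·N) ≈ 1.007·10⁸, so an ave_time of 1.0 ms reports roughly 137.4 TFlops and 100.7 GB/s with the formulas shown.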
@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc,
        f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

-    // TODO: add different init types
    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
...
@@ -164,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
    c_m_n_gpu_ref.SetZero();
    c_m_n_gpu_buf_ref.SetZero();

+    ADataType* d_A;
+    BDataType* d_B;
+    CDataType* d_C;
+
+    ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
+    ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
+    ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
+
+    ck_tile::hip_check_error(hipMemcpy(
+        d_A, a_m_k_dev_buf.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice));
+    ck_tile::hip_check_error(hipMemcpy(
+        d_B, b_k_n_dev_buf.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice));
+
    ck_tile::reference_gemm_gpu<ADataType,
                                BDataType,
                                AccDataType,
                                CDataType,
                                ALayout,
                                BLayout,
-                                CLayout>(
-        a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
+                                CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
+
+    ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
+                                       d_C,
+                                       M * N * sizeof(CDataType),
+                                       hipMemcpyDeviceToHost));
+
+    ck_tile::hip_check_error(hipFree(d_A));
+    ck_tile::hip_check_error(hipFree(d_B));
+    ck_tile::hip_check_error(hipFree(d_C));

    c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());

    pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
...
@@ -202,14 +224,16 @@ int run_gemm_example(int argc, char* argv[])
    {
        return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
    }
-    else if(a_layout == "C" && b_layout == "C")
-    {
-        return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
-    }
-    else if(a_layout == "C" && b_layout == "R")
-    {
-        return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
-    }
+    // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
+    // work.
+    // else if(a_layout == "C" && b_layout == "C")
+    // {
+    //     return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
+    // }
+    // else if(a_layout == "C" && b_layout == "R")
+    // {
+    //     return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
+    // }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
...