Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b74918bc
Commit
b74918bc
authored
Jan 06, 2025
by
ThomasNing
Browse files
compiled version of cross gpu connection
parents
3fcad951
1c45ca35
Changes
486
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
620 additions
and
442 deletions
+620
-442
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
...lti_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
+3
-3
example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp
...62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp
+5
-4
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
...iply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
+0
-3
example/CMakeLists.txt
example/CMakeLists.txt
+7
-0
example/README.md
example/README.md
+2
-0
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+12
-4
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+7
-7
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+31
-22
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
+6
-3
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+87
-67
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+7
-7
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+114
-106
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+144
-72
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+126
-86
example/ck_tile/01_fmha/utils.hpp
example/ck_tile/01_fmha/utils.hpp
+2
-2
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+1
-1
example/ck_tile/03_gemm/README.md
example/ck_tile/03_gemm/README.md
+3
-0
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+19
-21
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+1
-15
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+43
-19
No files found.
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
View file @
b74918bc
...
@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
break
;
break
;
default:
default:
a0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
a0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
A0DataType
,
0
>
{});
a1_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
a1_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
A1DataType
,
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
}
}
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
-
0.5
,
0.5
});
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
-
0.5
,
0.5
});
...
...
example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp
View file @
b74918bc
...
@@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification,
...
@@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification,
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
// values generated: -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5
wei
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
6
});
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
1.0
,
1.0
});
break
;
break
;
default:
default:
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0
.0
,
1
.0
});
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
-
5
.0
,
5
.0
});
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
1.0
,
1.0
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpaceSize
());
...
...
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
View file @
b74918bc
...
@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
...
@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
a1_device_buf
.
ToDevice
(
a1_m_k
.
mData
.
data
());
a1_device_buf
.
ToDevice
(
a1_m_k
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_k_n
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_k_n
.
mData
.
data
());
b1_device_buf
.
ToDevice
(
b1_k_n
.
mData
.
data
());
b1_device_buf
.
ToDevice
(
b1_k_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
...
@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
...
@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
<<
std
::
endl
;
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
if
(
do_verification
)
if
(
do_verification
)
{
{
Tensor
<
AccDataType
>
c_m_n
({
M
,
N
});
Tensor
<
AccDataType
>
c_m_n
({
M
,
N
});
...
...
example/CMakeLists.txt
View file @
b74918bc
...
@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
...
@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
#Do not build any DPP examples if DL_KERNELS not set
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dpp"
)
message
(
"removing dpp example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endforeach
()
#Do not build any XDL examples if gfx9 targets are not on the list
#Do not build any XDL examples if gfx9 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT EX_TARGETS MATCHES
"gfx9"
AND source MATCHES
"_xdl"
)
if
(
NOT EX_TARGETS MATCHES
"gfx9"
AND source MATCHES
"_xdl"
)
...
...
example/README.md
0 → 100644
View file @
b74918bc
[
Back to the main page
](
../README.md
)
# Composable Kernel examples
\ No newline at end of file
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
View file @
b74918bc
...
@@ -2,10 +2,17 @@
...
@@ -2,10 +2,17 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
# generate kernel instances to speed up compilation
DTYPE_MAP
=
{
FWD_DTYPE_MAP
=
{
"fp16"
:
"ck_tile::fp16_t"
,
"fp16"
:
"FmhaFwdFp16"
,
"bf16"
:
"ck_tile::bf16_t"
,
"bf16"
:
"FmhaFwdBf16"
,
"fp8"
:
"ck_tile::fp8_t"
"fp8"
:
"FmhaFwdFp8"
,
"fp8fp16"
:
"FmhaFwdFp8Fp16"
,
"fp8bf16"
:
"FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP
=
{
"fp16"
:
"FmhaBwdFp16"
,
"bf16"
:
"FmhaBwdBf16"
}
}
MASK_IMPL
=
{
MASK_IMPL
=
{
...
@@ -112,6 +119,7 @@ PIPELINE_MAP = {
...
@@ -112,6 +119,7 @@ PIPELINE_MAP = {
PIPELINE_ENUM_MAP
=
{
PIPELINE_ENUM_MAP
=
{
"qr"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS"
,
"qr"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS"
,
"qr_async"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC"
,
"qr_async"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC"
,
"qr_nwarp_sshuffle"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS"
,
}
}
BOOL_MAP
=
{
BOOL_MAP
=
{
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
b74918bc
...
@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
...
@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
inners
=
inners
+
FMHA_BWD_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_pipeline_enum
=
BWD_DQDKDV_PIPELINE_ENUM_MAP
[
trait
.
pipeline
],
inners
=
inners
+
FMHA_BWD_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_pipeline_enum
=
BWD_DQDKDV_PIPELINE_ENUM_MAP
[
trait
.
pipeline
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_dbias
=
BOOL_MAP
[
trait
.
dbias
],
F_dropout_check
=
DROPOUT_CHECK_MAP
[
trait
.
dropout
],
F_dropout
=
DROPOUT_MAP
[
trait
.
dropout
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_dbias
=
BOOL_MAP
[
trait
.
dbias
],
F_dropout_check
=
DROPOUT_CHECK_MAP
[
trait
.
dropout
],
F_dropout
=
DROPOUT_MAP
[
trait
.
dropout
],
F_scheck
=
trait
.
scheck
(
spad1
=
spad1
),
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
],
F_scheck
=
trait
.
scheck
(
spad1
=
spad1
),
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_hdim
=
hdim
,
F_dtype
=
BWD_
DTYPE_MAP
[
dtype
],
F_spad0
=
BOOL_MAP
[
trait
.
spad
],
F_spad1
=
BOOL_MAP
[
spad1
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_spad0
=
BOOL_MAP
[
trait
.
spad
],
F_spad1
=
BOOL_MAP
[
spad1
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_deterministic
=
BOOL_MAP
[
trait
.
deterministic
])
F_deterministic
=
BOOL_MAP
[
trait
.
deterministic
])
...
@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
...
@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
FMHA_BWD_DQ_DK_DV_KERNEL_BODY
.
format
(
FMHA_BWD_DQ_DK_DV_KERNEL_BODY
.
format
(
F_idx
=
self
.
F_idx
,
F_idx
=
self
.
F_idx
,
F_hdim
=
self
.
F_hdim
,
F_hdim
=
self
.
F_hdim
,
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_dtype
=
BWD_
DTYPE_MAP
[
self
.
F_dtype
],
F_bm0
=
self
.
F_tile
.
F_bm0
,
F_bm0
=
self
.
F_tile
.
F_bm0
,
F_bn0
=
self
.
F_tile
.
F_bn0
,
F_bn0
=
self
.
F_tile
.
F_bn0
,
F_bk0
=
self
.
F_tile
.
F_bk0
,
F_bk0
=
self
.
F_tile
.
F_bk0
,
...
@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen
=
list
()
gen
=
list
()
api_pool
=
FmhaBwdApiPool
(
mask_impl
)
api_pool
=
FmhaBwdApiPool
(
mask_impl
)
for
dtype
in
DTYPE_MAP
.
keys
():
for
dtype
in
BWD_
DTYPE_MAP
.
keys
():
d
=
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
)
d
=
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
)
if
d
==
None
:
if
d
==
None
:
continue
continue
...
@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
...
@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
FMHA_BWD_DOT_DO_O_KERNEL_BODY
.
format
(
FMHA_BWD_DOT_DO_O_KERNEL_BODY
.
format
(
F_idx
=
self
.
F_idx
,
F_idx
=
self
.
F_idx
,
F_hdim
=
self
.
F_hdim
,
F_hdim
=
self
.
F_hdim
,
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_dtype
=
BWD_
DTYPE_MAP
[
self
.
F_dtype
],
F_spad
=
BOOL_MAP
[
self
.
F_spad
],
F_spad
=
BOOL_MAP
[
self
.
F_spad
],
F_dvpad
=
BOOL_MAP
[
self
.
F_dvpad
],
F_dvpad
=
BOOL_MAP
[
self
.
F_dvpad
],
F_mode
=
MODE_MAP
[
self
.
F_mode
],
F_mode
=
MODE_MAP
[
self
.
F_mode
],
...
@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
...
@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
gen
=
list
()
gen
=
list
()
for
dtype
in
DTYPE_MAP
.
keys
():
for
dtype
in
BWD_
DTYPE_MAP
.
keys
():
d
=
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
)
d
=
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
)
if
d
==
None
:
if
d
==
None
:
continue
continue
...
@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
...
@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY
.
format
(
FMHA_BWD_CONVERT_DQ_KERNEL_BODY
.
format
(
F_idx
=
self
.
F_idx
,
F_idx
=
self
.
F_idx
,
F_hdim
=
self
.
F_hdim
,
F_hdim
=
self
.
F_hdim
,
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_dtype
=
BWD_
DTYPE_MAP
[
self
.
F_dtype
],
F_bm0
=
self
.
F_bm0
,
F_bm0
=
self
.
F_bm0
,
F_bn0
=
self
.
F_bn0
,
F_bn0
=
self
.
F_bn0
,
F_spad
=
BOOL_MAP
[
self
.
F_spad
],
F_spad
=
BOOL_MAP
[
self
.
F_spad
],
...
@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
...
@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
gen
=
list
()
gen
=
list
()
for
dtype
in
DTYPE_MAP
.
keys
():
for
dtype
in
BWD_
DTYPE_MAP
.
keys
():
d
=
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
)
d
=
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
)
if
d
==
None
:
if
d
==
None
:
continue
continue
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
b74918bc
...
@@ -44,13 +44,12 @@ FMHA_FWD_KERNEL_BODY="""
...
@@ -44,13 +44,12 @@ FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile_{F_idx}
,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>
,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile_{F_idx}
,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>
,
{F_vlayout}>;
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
...
@@ -282,7 +281,7 @@ class FmhaFwdApiPool:
...
@@ -282,7 +281,7 @@ class FmhaFwdApiPool:
F_squant
=
BOOL_MAP
[
trait
.
squant
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_squant
=
BOOL_MAP
[
trait
.
squant
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0max
=
trait
.
bk0max
,
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0max
=
trait
.
bk0max
,
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
])
F_hdim
=
hdim
,
F_dtype
=
FWD_
DTYPE_MAP
[
dtype
])
if_j
=
'if'
if
j
==
0
else
'else if'
if_j
=
'if'
if
j
==
0
else
'else if'
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
if_i
=
'if'
if
i
==
0
else
'else if'
if_i
=
'if'
if
i
==
0
else
'else if'
...
@@ -301,20 +300,24 @@ class FmhaFwdTileSize:
...
@@ -301,20 +300,24 @@ class FmhaFwdTileSize:
F_bk1
:
int
# tile size along kv gemm unroll
F_bk1
:
int
# tile size along kv gemm unroll
F_bk0max
:
int
# total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_bk0max
:
int
# total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0
:
int
# number of warps for gemm0 along q seqlen
F_rm0
:
int
# number of warps for gemm0 along q seqlen
F_rn0
:
int
# number of warps for gemm0 along k seqlen
F_rn0
:
int
# number of warps for gemm0 along k seqlen
F_rk0
:
int
# number of warps for gemm0 along head dim q (not used)
F_rk0
:
int
# number of warps for gemm0 along head dim q (not used)
F_rm1
:
int
# number of warps for gemm1 along q seqlen
F_rm1
:
int
# number of warps for gemm1 along q seqlen
F_rn1
:
int
# number of warps for gemm1 along head dim v
F_rn1
:
int
# number of warps for gemm1 along head dim v
F_rk1
:
int
# number of warps for gemm1 along k seqlen (not used)
F_rk1
:
int
# number of warps for gemm1 along k seqlen (not used)
F_wm
:
int
# warp size along m (warp size)
F_wm0
:
int
# gemm0 warp size along m
F_wn
:
int
# warp size along n
F_wn0
:
int
# gemm0 warp size along n
F_wk
:
int
# warp size along k
F_wk0
:
int
# gemm0 warp size along k
F_wm1
:
int
# gemm1 warp size along m
F_wn1
:
int
# gemm1 warp size along n
F_wk1
:
int
# gemm1 warp size along k
F_occupancy
:
int
# occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_occupancy
:
int
# occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@
property
@
property
def
name
(
self
)
->
str
:
def
name
(
self
)
->
str
:
return
f
"b
{
self
.
F_bm0
}
x
{
self
.
F_bn0
}
x
{
self
.
F_bk0
}
x
{
self
.
F_bn1
}
x
{
self
.
F_bk1
}
x
{
self
.
F_bk0max
}
"
+
\
return
f
"b
{
self
.
F_bm0
}
x
{
self
.
F_bn0
}
x
{
self
.
F_bk0
}
x
{
self
.
F_bn1
}
x
{
self
.
F_bk1
}
x
{
self
.
F_bk0max
}
"
+
\
f
"_r
{
self
.
F_rm0
}
x
{
self
.
F_rn0
}
x
{
self
.
F_rk0
}
_r
{
self
.
F_rm1
}
x
{
self
.
F_rn1
}
x
{
self
.
F_rk1
}
"
+
\
f
"_r
{
self
.
F_rm0
}
x
{
self
.
F_rn0
}
x
{
self
.
F_rk0
}
_r
{
self
.
F_rm1
}
x
{
self
.
F_rn1
}
x
{
self
.
F_rk1
}
"
+
\
f
"_w
{
self
.
F_wm
}
x
{
self
.
F_wn
}
x
{
self
.
F_wk
}
"
+
(
""
if
self
.
F_occupancy
==
-
1
else
f
"_o
{
self
.
F_occupancy
}
"
)
f
"_w
{
self
.
F_wm0
}
x
{
self
.
F_wn0
}
x
{
self
.
F_wk0
}
_w
{
self
.
F_wm1
}
x
{
self
.
F_wn1
}
x
{
self
.
F_wk1
}
"
+
\
(
""
if
self
.
F_occupancy
==
-
1
else
f
"_o
{
self
.
F_occupancy
}
"
)
@
dataclass
@
dataclass
class
FmhaFwdKernel
:
class
FmhaFwdKernel
:
...
@@ -339,7 +342,7 @@ class FmhaFwdKernel:
...
@@ -339,7 +342,7 @@ class FmhaFwdKernel:
FMHA_FWD_KERNEL_BODY
.
format
(
FMHA_FWD_KERNEL_BODY
.
format
(
F_idx
=
self
.
F_idx
,
F_idx
=
self
.
F_idx
,
F_hdim
=
self
.
F_hdim
,
F_hdim
=
self
.
F_hdim
,
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_dtype
=
FWD_
DTYPE_MAP
[
self
.
F_dtype
],
F_bm0
=
self
.
F_tile
.
F_bm0
,
F_bm0
=
self
.
F_tile
.
F_bm0
,
F_bn0
=
self
.
F_tile
.
F_bn0
,
F_bn0
=
self
.
F_tile
.
F_bn0
,
F_bk0
=
self
.
F_tile
.
F_bk0
,
F_bk0
=
self
.
F_tile
.
F_bk0
,
...
@@ -352,9 +355,12 @@ class FmhaFwdKernel:
...
@@ -352,9 +355,12 @@ class FmhaFwdKernel:
F_rm1
=
self
.
F_tile
.
F_rm1
,
F_rm1
=
self
.
F_tile
.
F_rm1
,
F_rn1
=
self
.
F_tile
.
F_rn1
,
F_rn1
=
self
.
F_tile
.
F_rn1
,
F_rk1
=
self
.
F_tile
.
F_rk1
,
F_rk1
=
self
.
F_tile
.
F_rk1
,
F_wm
=
self
.
F_tile
.
F_wm
,
F_wm0
=
self
.
F_tile
.
F_wm0
,
F_wn
=
self
.
F_tile
.
F_wn
,
F_wn0
=
self
.
F_tile
.
F_wn0
,
F_wk
=
self
.
F_tile
.
F_wk
,
F_wk0
=
self
.
F_tile
.
F_wk0
,
F_wm1
=
self
.
F_tile
.
F_wm1
,
F_wn1
=
self
.
F_tile
.
F_wn1
,
F_wk1
=
self
.
F_tile
.
F_wk1
,
F_vlayout
=
LAYOUT_MAP
[
self
.
F_pipeline
.
F_vlayout
],
F_vlayout
=
LAYOUT_MAP
[
self
.
F_pipeline
.
F_vlayout
],
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_skpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_skpad
],
F_skpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_skpad
],
...
@@ -409,17 +415,17 @@ class FmhaFwdKernel:
...
@@ -409,17 +415,17 @@ class FmhaFwdKernel:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
FmhaFwdTileSize
(
128
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'32'
:
FmhaFwdTileSize
(
128
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
-
1
),
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
#
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32,
96, 4, 1, 1, 4, 1, 1, 32, 32, 16,
32, 32, 16,
-1),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
-
1
),
}
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
return
{
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
)
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
)
,
}
}
else
:
else
:
return
None
return
None
...
@@ -462,6 +468,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
...
@@ -462,6 +468,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
# no need lse/dropout kernels
# no need lse/dropout kernels
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
pipelines
.
append
(
FmhaFwdPipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
'f'
,
squant
,
mask
))
pipelines
.
append
(
FmhaFwdPipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
'f'
,
squant
,
mask
))
elif
dtype
in
[
'fp8fp16'
,
'fp8bf16'
]:
# TODO
None
else
:
else
:
assert
False
assert
False
return
pipelines
return
pipelines
...
@@ -469,7 +478,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
...
@@ -469,7 +478,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
gen
=
list
()
gen
=
list
()
api_pool
=
FmhaFwdApiPool
(
mask_impl
)
api_pool
=
FmhaFwdApiPool
(
mask_impl
)
for
dtype
in
DTYPE_MAP
.
keys
():
for
dtype
in
FWD_
DTYPE_MAP
.
keys
():
d
=
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
d
=
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
if
d
==
None
:
if
d
==
None
:
continue
continue
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
View file @
b74918bc
...
@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
...
@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
inners
=
inners
+
FMHA_FWD_APPENDKV_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
inners
=
inners
+
FMHA_FWD_APPENDKV_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_rope_check
=
ROPE_CHECK_MAP
[
trait
.
rope
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_rope_check
=
ROPE_CHECK_MAP
[
trait
.
rope
],
F_pagedkv
=
BOOL_MAP
[
trait
.
pagedkv
],
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_pagedkv
=
BOOL_MAP
[
trait
.
pagedkv
],
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_rope
=
ROPE_MAP
[
trait
.
rope
],
F_bs
=
trait
.
bs
,
F_bsk
=
trait
.
bsk
,
F_bd
=
trait
.
bd
,
F_bdv
=
trait
.
bdv
,
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
])
F_rope
=
ROPE_MAP
[
trait
.
rope
],
F_bs
=
trait
.
bs
,
F_bsk
=
trait
.
bsk
,
F_bd
=
trait
.
bd
,
F_bdv
=
trait
.
bdv
,
F_hdim
=
hdim
,
F_dtype
=
FWD_
DTYPE_MAP
[
dtype
])
if_j
=
'if'
if
j
==
0
else
'else if'
if_j
=
'if'
if
j
==
0
else
'else if'
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
if_i
=
'if'
if
i
==
0
else
'else if'
if_i
=
'if'
if
i
==
0
else
'else if'
...
@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
...
@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
FMHA_FWD_APPENDKV_KERNEL_BODY
.
format
(
FMHA_FWD_APPENDKV_KERNEL_BODY
.
format
(
F_idx
=
self
.
F_idx
,
F_idx
=
self
.
F_idx
,
F_hdim
=
self
.
F_hdim
,
F_hdim
=
self
.
F_hdim
,
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_dtype
=
FWD_
DTYPE_MAP
[
self
.
F_dtype
],
F_bs
=
self
.
F_tile
.
F_bs
,
F_bs
=
self
.
F_tile
.
F_bs
,
F_bsk
=
self
.
F_tile
.
F_bsk
,
F_bsk
=
self
.
F_tile
.
F_bsk
,
F_bd
=
self
.
F_tile
.
F_bd
,
F_bd
=
self
.
F_tile
.
F_bd
,
...
@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif
dtype
in
[
'fp8'
,
'bf8'
]:
elif
dtype
in
[
'fp8'
,
'bf8'
]:
# rope/paged-kv is not supported
# rope/paged-kv is not supported
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
'col'
,
't'
,
't'
,
't'
,
't'
,
'no'
,
'f'
))
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
'col'
,
't'
,
't'
,
't'
,
't'
,
'no'
,
'f'
))
elif
dtype
in
[
'fp8fp16'
,
'fp8bf16'
]:
# TODO
None
else
:
else
:
assert
False
assert
False
return
pipelines
return
pipelines
...
@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen
=
list
()
gen
=
list
()
api_pool
=
FmhaFwdAppendKVApiPool
(
mask_impl
)
api_pool
=
FmhaFwdAppendKVApiPool
(
mask_impl
)
for
dtype
in
DTYPE_MAP
.
keys
():
for
dtype
in
FWD_
DTYPE_MAP
.
keys
():
d
=
get_fmha_fwd_appendkv_tile_dict_from_dtype
(
dtype
)
d
=
get_fmha_fwd_appendkv_tile_dict_from_dtype
(
dtype
)
if
d
==
None
:
if
d
==
None
:
continue
continue
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
b74918bc
This diff is collapsed.
Click to expand it.
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
b74918bc
...
@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
...
@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
}
}
// different threshold for different dtype
// different threshold for different dtype
template
<
typename
DataType
>
template
<
typename
DataType
Config
>
auto
get_elimit
(
ck_tile
::
index_t
/*hdim_q*/
,
ck_tile
::
index_t
/*hdim_v*/
)
auto
get_elimit
(
ck_tile
::
index_t
/*hdim_q*/
,
ck_tile
::
index_t
/*hdim_v*/
)
{
{
double
rtol
=
1e-2
;
double
rtol
=
1e-2
;
...
@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
...
@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
}
}
template
<
>
template
<
>
auto
get_elimit
<
ck_tile
::
b
f16
_t
>
(
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
)
auto
get_elimit
<
FmhaBwdB
f16
>
(
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
)
{
{
double
rtol
=
1e-2
;
double
rtol
=
1e-2
;
double
atol
=
1e-2
;
double
atol
=
1e-2
;
...
@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
...
@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
}
template
<
typename
DataType
>
template
<
typename
DataType
Config
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
...
@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
auto
seqstart_q_host
=
generate_seqstarts
(
mode
,
batch
,
seqlen_q
);
const
auto
seqstart_q_host
=
generate_seqstarts
(
mode
,
batch
,
seqlen_q
);
const
auto
seqstart_k_host
=
generate_seqstarts
(
mode
,
batch
,
seqlen_k
);
const
auto
seqstart_k_host
=
generate_seqstarts
(
mode
,
batch
,
seqlen_k
);
using
TypeConfig
=
FmhaBwdTypeConfig
<
DataType
>
;
using
TypeConfig
=
FmhaBwdTypeConfig
<
DataType
Config
>
;
using
QDataType
=
typename
TypeConfig
::
QDataType
;
using
QDataType
=
typename
TypeConfig
::
QDataType
;
using
KDataType
=
typename
TypeConfig
::
KDataType
;
using
KDataType
=
typename
TypeConfig
::
KDataType
;
...
@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
// clang-format on
// clang-format on
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
hdim_q
,
hdim_v
);
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
Config
>
(
hdim_q
,
hdim_v
);
bool
dq_cur_pass
=
ck_tile
::
check_err
(
dq_host_result
,
bool
dq_cur_pass
=
ck_tile
::
check_err
(
dq_host_result
,
dq_host_ref
,
dq_host_ref
,
std
::
string
(
"Error: QGrad Incorrect results!"
),
std
::
string
(
"Error: QGrad Incorrect results!"
),
...
@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
...
@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
if
(
data_type
==
"fp16"
)
if
(
data_type
==
"fp16"
)
{
{
return
run
<
ck_tile
::
half_t
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
FmhaBwdFp16
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"bf16"
)
else
if
(
data_type
==
"bf16"
)
{
{
return
run
<
ck_tile
::
b
f16
_t
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
FmhaBwdB
f16
>
(
arg_parser
)
?
0
:
-
2
;
}
}
return
-
3
;
return
-
3
;
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
b74918bc
...
@@ -14,11 +14,19 @@
...
@@ -14,11 +14,19 @@
#include <utility>
#include <utility>
#include <variant>
#include <variant>
struct
FmhaBwdFp16
{
};
struct
FmhaBwdBf16
{
};
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
FmhaBwdTypeConfig
;
struct
FmhaBwdTypeConfig
;
template
<
>
template
<
>
struct
FmhaBwdTypeConfig
<
ck_tile
::
half_t
>
struct
FmhaBwdTypeConfig
<
FmhaBwdFp16
>
{
{
using
QDataType
=
ck_tile
::
half_t
;
using
QDataType
=
ck_tile
::
half_t
;
using
KDataType
=
ck_tile
::
half_t
;
using
KDataType
=
ck_tile
::
half_t
;
...
@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
...
@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
};
};
template
<
>
template
<
>
struct
FmhaBwdTypeConfig
<
ck_tile
::
b
f16
_t
>
struct
FmhaBwdTypeConfig
<
FmhaBwdB
f16
>
{
{
using
QDataType
=
ck_tile
::
bf16_t
;
using
QDataType
=
ck_tile
::
bf16_t
;
using
KDataType
=
ck_tile
::
bf16_t
;
using
KDataType
=
ck_tile
::
bf16_t
;
...
@@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
// create group mode kernel arguments
// create group mode kernel arguments
if
constexpr
(
FmhaBwdDQDKDVKernel
::
kIsGroupMode
)
if
constexpr
(
FmhaBwdDQDKDVKernel
::
kIsGroupMode
)
{
{
return
FmhaBwdDQDKDVKernel
::
MakeKargs
(
args
.
q_ptr
,
return
FmhaBwdDQDKDVKernel
::
MakeKargs
Impl
(
args
.
q_ptr
,
args
.
k_ptr
,
args
.
k_ptr
,
args
.
v_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
bias_ptr
,
args
.
lse_ptr
,
args
.
lse_ptr
,
args
.
do_ptr
,
args
.
do_ptr
,
args
.
d_ptr
,
args
.
d_ptr
,
args
.
rand_val_ptr
,
args
.
rand_val_ptr
,
args
.
dk_ptr
,
args
.
dk_ptr
,
args
.
dv_ptr
,
args
.
dv_ptr
,
args
.
dbias_ptr
,
args
.
dbias_ptr
,
args
.
dq_acc_ptr
,
args
.
dq_acc_ptr
,
args
.
seqstart_q_ptr
,
args
.
seqstart_q_ptr
,
args
.
seqstart_k_ptr
,
args
.
seqstart_k_ptr
,
args
.
seqlen_k_ptr
,
args
.
seqlen_k_ptr
,
args
.
hdim_q
,
args
.
hdim_q
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
nhead_q
,
args
.
nhead_q
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
scale
,
args
.
scale
,
args
.
stride_q
,
args
.
stride_q
,
args
.
stride_k
,
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_randval
,
args
.
stride_do
,
args
.
stride_do
,
args
.
stride_dq_acc
,
args
.
stride_dq_acc
,
args
.
stride_dk
,
args
.
stride_dk
,
args
.
stride_dv
,
args
.
stride_dv
,
args
.
stride_dbias
,
args
.
stride_dbias
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_do
,
args
.
nhead_stride_do
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_dq_acc
,
args
.
nhead_stride_dq_acc
,
args
.
nhead_stride_dk
,
args
.
nhead_stride_dk
,
args
.
nhead_stride_dv
,
args
.
nhead_stride_dv
,
args
.
nhead_stride_dbias
,
args
.
nhead_stride_dbias
,
args
.
split_stride_dq_acc
,
args
.
split_stride_dq_acc
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
args
.
p_drop
,
args
.
p_drop
,
args
.
drop_seed_offset
);
args
.
drop_seed_offset
);
}
}
else
else
{
// create batch mode kernel arguments
{
// create batch mode kernel arguments
return
FmhaBwdDQDKDVKernel
::
MakeKargs
(
args
.
q_ptr
,
return
FmhaBwdDQDKDVKernel
::
MakeKargs
Impl
(
args
.
q_ptr
,
args
.
k_ptr
,
args
.
k_ptr
,
args
.
v_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
bias_ptr
,
args
.
lse_ptr
,
args
.
lse_ptr
,
args
.
do_ptr
,
args
.
do_ptr
,
args
.
d_ptr
,
args
.
d_ptr
,
args
.
rand_val_ptr
,
args
.
rand_val_ptr
,
args
.
dk_ptr
,
args
.
dk_ptr
,
args
.
dv_ptr
,
args
.
dv_ptr
,
args
.
dbias_ptr
,
args
.
dbias_ptr
,
args
.
dq_acc_ptr
,
args
.
dq_acc_ptr
,
args
.
seqlen_q
,
args
.
seqlen_q
,
args
.
seqlen_k
,
args
.
seqlen_k
,
args
.
hdim_q
,
args
.
hdim_q
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
nhead_q
,
args
.
nhead_q
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
scale
,
args
.
scale
,
args
.
stride_q
,
args
.
stride_q
,
args
.
stride_k
,
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_randval
,
args
.
stride_do
,
args
.
stride_do
,
args
.
stride_dq_acc
,
args
.
stride_dq_acc
,
args
.
stride_dk
,
args
.
stride_dk
,
args
.
stride_dv
,
args
.
stride_dv
,
args
.
stride_dbias
,
args
.
stride_dbias
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_do
,
args
.
nhead_stride_do
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_dq_acc
,
args
.
nhead_stride_dq_acc
,
args
.
nhead_stride_dk
,
args
.
nhead_stride_dk
,
args
.
nhead_stride_dv
,
args
.
nhead_stride_dv
,
args
.
nhead_stride_dbias
,
args
.
nhead_stride_dbias
,
args
.
batch_stride_q
,
args
.
batch_stride_q
,
args
.
batch_stride_k
,
args
.
batch_stride_k
,
args
.
batch_stride_v
,
args
.
batch_stride_v
,
args
.
batch_stride_bias
,
args
.
batch_stride_bias
,
args
.
batch_stride_randval
,
args
.
batch_stride_randval
,
args
.
batch_stride_do
,
args
.
batch_stride_do
,
args
.
batch_stride_lsed
,
args
.
batch_stride_lsed
,
args
.
batch_stride_dq_acc
,
args
.
batch_stride_dq_acc
,
args
.
batch_stride_dk
,
args
.
batch_stride_dk
,
args
.
batch_stride_dv
,
args
.
batch_stride_dv
,
args
.
batch_stride_dbias
,
args
.
batch_stride_dbias
,
args
.
split_stride_dq_acc
,
args
.
split_stride_dq_acc
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
args
.
p_drop
,
args
.
p_drop
,
args
.
drop_seed_offset
);
args
.
drop_seed_offset
);
}
}
}();
}();
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
b74918bc
This diff is collapsed.
Click to expand it.
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
b74918bc
...
@@ -16,11 +16,35 @@
...
@@ -16,11 +16,35 @@
#include <utility>
#include <utility>
#include <variant>
#include <variant>
struct
FmhaFwdFp16
{
};
struct
FmhaFwdBf16
{
};
struct
FmhaFwdFp8
{
};
struct
FmhaFwdBf8
{
};
struct
FmhaFwdFp8Fp16
{
};
struct
FmhaFwdFp8Bf16
{
};
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
FmhaFwdTypeConfig
;
struct
FmhaFwdTypeConfig
;
template
<
>
template
<
>
struct
FmhaFwdTypeConfig
<
ck_tile
::
half_t
>
struct
FmhaFwdTypeConfig
<
FmhaFwdFp16
>
{
{
using
QDataType
=
ck_tile
::
half_t
;
using
QDataType
=
ck_tile
::
half_t
;
using
KDataType
=
ck_tile
::
half_t
;
using
KDataType
=
ck_tile
::
half_t
;
...
@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
...
@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
};
};
template
<
>
template
<
>
struct
FmhaFwdTypeConfig
<
ck_tile
::
b
f16
_t
>
struct
FmhaFwdTypeConfig
<
FmhaFwdB
f16
>
{
{
using
QDataType
=
ck_tile
::
bf16_t
;
using
QDataType
=
ck_tile
::
bf16_t
;
using
KDataType
=
ck_tile
::
bf16_t
;
using
KDataType
=
ck_tile
::
bf16_t
;
...
@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
...
@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
};
};
template
<
>
template
<
>
struct
FmhaFwdTypeConfig
<
ck_tile
::
fp8_t
>
struct
FmhaFwdTypeConfig
<
FmhaFwdFp8
>
{
{
using
QDataType
=
ck_tile
::
fp8_t
;
using
QDataType
=
ck_tile
::
fp8_t
;
using
KDataType
=
ck_tile
::
fp8_t
;
using
KDataType
=
ck_tile
::
fp8_t
;
...
@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
...
@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
};
};
template
<
>
template
<
>
struct
FmhaFwdTypeConfig
<
ck_tile
::
bf8_t
>
struct
FmhaFwdTypeConfig
<
FmhaFwdBf8
>
{
{
using
QDataType
=
ck_tile
::
bf8_t
;
using
QDataType
=
ck_tile
::
bf8_t
;
using
KDataType
=
ck_tile
::
bf8_t
;
using
KDataType
=
ck_tile
::
bf8_t
;
...
@@ -165,6 +189,8 @@ struct fmha_fwd_splitkv_args
...
@@ -165,6 +189,8 @@ struct fmha_fwd_splitkv_args
void
*
block_table_ptr
;
void
*
block_table_ptr
;
ck_tile
::
index_t
batch_stride_block_table
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
batch_stride_block_table
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
page_block_size
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
page_block_size
;
// only used if 'block_table_ptr' is not nullptr
bool
is_gappy
;
// differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
// nullptr.
const
void
*
cache_batch_idx
;
const
void
*
cache_batch_idx
;
...
@@ -173,9 +199,21 @@ struct fmha_fwd_splitkv_args
...
@@ -173,9 +199,21 @@ struct fmha_fwd_splitkv_args
// seqlen_k = kargs.seqlen_k
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// kvcache mode (use same kernel as batch mode):
// or kargs.seqlen_k_ptr[b]
//
// batch mode (kvcache):
// seqlen_q = kargs.seqlen_q
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k_ptr[b]
// group mode (kvcache):
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
//
// when is_gappy=true:
// seqlen_k = kargs.seqlen_k_ptr[b]
// seqstart_k_ptr[b] now store local offset of each batch
//
// when is_gappy=false:
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// or kargs.seqlen_k_ptr[b]
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqlen_k_ptr
;
const
void
*
seqlen_k_ptr
;
...
@@ -251,7 +289,7 @@ struct fmha_fwd_appendkv_args
...
@@ -251,7 +289,7 @@ struct fmha_fwd_appendkv_args
ck_tile
::
index_t
batch_stride_block_table
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
batch_stride_block_table
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
page_block_size
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
page_block_size
;
// only used if 'block_table_ptr' is not nullptr
const
void
*
cache_batch_idx
;
const
void
*
cache_batch_idx
;
// only used if block_table_ptr is nullptr -> batch mode (kvcache)
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_k
;
...
@@ -278,87 +316,87 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -278,87 +316,87 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
// create group mode kernel arguments
// create group mode kernel arguments
if
constexpr
(
FmhaKernel
::
kIsGroupMode
)
if
constexpr
(
FmhaKernel
::
kIsGroupMode
)
{
{
return
FmhaKernel
::
MakeKargs
(
args
.
q_ptr
,
return
FmhaKernel
::
MakeKargs
Impl
(
args
.
q_ptr
,
args
.
k_ptr
,
args
.
k_ptr
,
args
.
v_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
bias_ptr
,
args
.
rand_val_ptr
,
args
.
rand_val_ptr
,
args
.
lse_ptr
,
args
.
lse_ptr
,
args
.
o_ptr
,
args
.
o_ptr
,
args
.
seqstart_q_ptr
,
args
.
seqstart_q_ptr
,
args
.
seqstart_k_ptr
,
args
.
seqstart_k_ptr
,
args
.
seqlen_k_ptr
,
args
.
seqlen_k_ptr
,
args
.
hdim_q
,
args
.
hdim_q
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
nhead_q
,
args
.
nhead_q
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
scale_s
,
args
.
scale_s
,
args
.
scale_p
,
args
.
scale_p
,
args
.
scale_o
,
args
.
scale_o
,
args
.
stride_q
,
args
.
stride_q
,
args
.
stride_k
,
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_randval
,
args
.
stride_o
,
args
.
stride_o
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_o
,
args
.
nhead_stride_o
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
args
.
p_drop
,
args
.
p_drop
,
args
.
s_randval
,
args
.
s_randval
,
args
.
drop_seed_offset
);
args
.
drop_seed_offset
);
}
}
else
else
{
// create batch mode kernel arguments
{
// create batch mode kernel arguments
return
FmhaKernel
::
MakeKargs
(
args
.
q_ptr
,
return
FmhaKernel
::
MakeKargs
Impl
(
args
.
q_ptr
,
args
.
k_ptr
,
args
.
k_ptr
,
args
.
v_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
bias_ptr
,
args
.
rand_val_ptr
,
args
.
rand_val_ptr
,
args
.
lse_ptr
,
args
.
lse_ptr
,
args
.
o_ptr
,
args
.
o_ptr
,
args
.
seqlen_q
,
args
.
seqlen_q
,
args
.
seqlen_k
,
args
.
seqlen_k
,
args
.
hdim_q
,
args
.
hdim_q
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
nhead_q
,
args
.
nhead_q
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
scale_s
,
args
.
scale_s
,
args
.
scale_p
,
args
.
scale_p
,
args
.
scale_o
,
args
.
scale_o
,
args
.
stride_q
,
args
.
stride_q
,
args
.
stride_k
,
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_randval
,
args
.
stride_o
,
args
.
stride_o
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_o
,
args
.
nhead_stride_o
,
args
.
batch_stride_q
,
args
.
batch_stride_q
,
args
.
batch_stride_k
,
args
.
batch_stride_k
,
args
.
batch_stride_v
,
args
.
batch_stride_v
,
args
.
batch_stride_bias
,
args
.
batch_stride_bias
,
args
.
batch_stride_randval
,
args
.
batch_stride_randval
,
args
.
batch_stride_lse
,
args
.
batch_stride_lse
,
args
.
batch_stride_o
,
args
.
batch_stride_o
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
args
.
p_drop
,
args
.
p_drop
,
args
.
s_randval
,
args
.
s_randval
,
args
.
drop_seed_offset
);
args
.
drop_seed_offset
);
}
}
}();
}();
...
@@ -389,6 +427,10 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -389,6 +427,10 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
nhead_q
,
args
.
nhead_q
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
num_splits
,
args
.
num_splits
,
args
.
block_table_ptr
,
args
.
batch_stride_block_table
,
args
.
page_block_size
,
args
.
is_gappy
,
args
.
scale_s
,
args
.
scale_s
,
args
.
scale_p
,
args
.
scale_p
,
args
.
stride_q
,
args
.
stride_q
,
...
@@ -667,7 +709,6 @@ std::string fmha_fwd_splitkv_get_name_();
...
@@ -667,7 +709,6 @@ std::string fmha_fwd_splitkv_get_name_();
template
<
ck_tile
::
index_t
HDim_
,
template
<
ck_tile
::
index_t
HDim_
,
typename
DataType_
,
typename
DataType_
,
bool
kIsGroupMode_
,
bool
kIsGroupMode_
,
ck_tile
::
index_t
kM0_
,
ck_tile
::
index_t
kN1_
,
ck_tile
::
index_t
kN1_
,
bool
kStoreLse_
,
bool
kStoreLse_
,
bool
kDoFp8StaticQuant_
,
bool
kDoFp8StaticQuant_
,
...
@@ -678,7 +719,6 @@ struct fmha_fwd_splitkv_combine_traits_
...
@@ -678,7 +719,6 @@ struct fmha_fwd_splitkv_combine_traits_
static
constexpr
ck_tile
::
index_t
HDim
=
HDim_
;
static
constexpr
ck_tile
::
index_t
HDim
=
HDim_
;
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
static
constexpr
bool
kStoreLse
=
kStoreLse_
;
static
constexpr
bool
kStoreLse
=
kStoreLse_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
...
...
example/ck_tile/01_fmha/utils.hpp
View file @
b74918bc
...
@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
...
@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
std
::
string
k_val
,
std
::
string
k_val
,
std
::
string
k_pad_val
,
std
::
string
k_pad_val
,
ck_tile
::
index_t
seqlen_k_min
=
0
,
ck_tile
::
index_t
seqlen_k_min
=
0
,
bool
use_kvcache
=
false
,
bool
need_append_kvcache
=
false
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
...
@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
...
@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
const
ck_tile
::
index_t
seqlen_k_max
=
(
k
<
0
?
q
:
k
);
const
ck_tile
::
index_t
seqlen_k_max
=
(
k
<
0
?
q
:
k
);
std
::
vector
<
ck_tile
::
index_t
>
seqlen_ks
(
batch
,
seqlen_k_max
);
std
::
vector
<
ck_tile
::
index_t
>
seqlen_ks
(
batch
,
seqlen_k_max
);
if
(
1
<
batch
&&
use
_kvcache
)
if
(
1
<
batch
&&
need_append
_kvcache
)
{
{
// to keep the original s_k value, we always use seqlen_k_max in first batch
// to keep the original s_k value, we always use seqlen_k_max in first batch
randints
(
std
::
next
(
seqlen_ks
.
begin
()),
randints
(
std
::
next
(
seqlen_ks
.
begin
()),
...
...
example/ck_tile/03_gemm/CMakeLists.txt
View file @
b74918bc
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_
gemm_mem_pipeline
EXCLUDE_FROM_ALL
gemm_mem_pipeline
.cpp
)
add_executable
(
tile_example_
universal_gemm
EXCLUDE_FROM_ALL
universal_gemm
.cpp
)
example/ck_tile/03_gemm/README.md
View file @
b74918bc
...
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
...
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
mkdir build && cd build
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j
make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation
make tile_example_gemm_mem_pipeline -j
```
```
This will result in an executable
`build/bin/tile_example_gemm_basic`
This will result in an executable
`build/bin/tile_example_gemm_basic`
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
b74918bc
...
@@ -15,12 +15,13 @@
...
@@ -15,12 +15,13 @@
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
gemm_basic_a
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadC
=
true
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
...
@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
CShuffleEpilogue
,
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CDataType
,
kPad
A
,
kPad
M
,
kPad
B
,
kPad
N
,
kTilePermute
,
kTilePermute
,
kOutputRank
,
kOutputRank
,
1
,
1
,
...
@@ -65,32 +66,29 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -65,32 +66,29 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
TilePartitioner
::
kM
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPad
A
,
kPad
B
>>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPad
M
,
kPad
N
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPad
A
,
kPad
B
,
kPad
C
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPad
M
,
kPad
N
,
kPad
K
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
<
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
using
CodegenGemmPipeline
=
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
args
.
p_b
,
args
.
p_c
,
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
{
throw
std
::
runtime_error
(
"Wrong! Arguments not supported! Skipping gemm!
\n
"
);
}
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel with args:"
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
b74918bc
...
@@ -51,20 +51,6 @@ using BDataType = Types::BDataType;
...
@@ -51,20 +51,6 @@ using BDataType = Types::BDataType;
using
AccDataType
=
Types
::
AccDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
using
CDataType
=
Types
::
CDataType
;
struct
gemm_basic_args
{
const
void
*
p_a
;
const
void
*
p_b
;
void
*
p_c
;
ck_tile
::
index_t
kbatch
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
K
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
};
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
...
@@ -89,4 +75,4 @@ auto create_args(int argc, char* argv[])
...
@@ -89,4 +75,4 @@ auto create_args(int argc, char* argv[])
}
}
// host API
// host API
float
gemm_calc
(
gemm_basic_a
rgs
args
,
const
ck_tile
::
stream_config
&
s
);
float
gemm_calc
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
b74918bc
...
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int
n_warmup
,
int
n_warmup
,
int
n_repeat
)
int
n_repeat
)
{
{
gemm_basic_a
rgs
args
;
ck_tile
::
GemmHostA
rgs
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
kbatch
=
kbatch
;
args
.
k
_
batch
=
kbatch
;
args
.
M
=
M
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
N
=
N
;
args
.
K
=
K
;
args
.
K
=
K
;
...
@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float
ave_time
=
gemm_calc
<
ALayout
,
BLayout
,
CLayout
>
(
float
ave_time
=
gemm_calc
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
string
op_name
{
"Gemm
{
MemBoundPipeline}"
}
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_byte
=
std
::
size_t
num_byte
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
N
*
K
+
sizeof
(
CDataType
)
*
M
*
N
;
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
N
*
K
+
sizeof
(
CDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std::cout << "
Run
" << op_name << "
kernel
with
M
=
" << M << "
N
=
" << N << "
K
=
" << K
std
::
cout
<<
"Run
Gemm
kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc,
f_host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
f_host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
// TODO: add different init types
// TODO: add different init types
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
...
@@ -164,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -164,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_A
,
M
*
K
*
sizeof
(
ADataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_B
,
N
*
K
*
sizeof
(
BDataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_C
,
M
*
N
*
sizeof
(
CDataType
)));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_A
,
a_m_k_dev_buf
.
GetDeviceBuffer
(),
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_B
,
b_k_n_dev_buf
.
GetDeviceBuffer
(),
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
));
ck_tile
::
reference_gemm_gpu
<
ADataType
,
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout>(
CLayout
>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
);
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
ck_tile
::
hip_check_error
(
hipMemcpy
(
c_m_n_gpu_buf_ref
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
));
ck_tile
::
hip_check_error
(
hipFree
(
d_A
));
ck_tile
::
hip_check_error
(
hipFree
(
d_B
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
...
@@ -202,14 +224,16 @@ int run_gemm_example(int argc, char* argv[])
...
@@ -202,14 +224,16 @@ int run_gemm_example(int argc, char* argv[])
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
}
else if(a_layout == "
C
" && b_layout == "
C
")
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
{
// work.
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// else if(a_layout == "C" && b_layout == "C")
}
// {
else if(a_layout == "
C
" && b_layout == "
R
")
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
{
// }
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// else if(a_layout == "C" && b_layout == "R")
}
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
else
{
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
...
...
Prev
1
2
3
4
5
6
7
8
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment