Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3b945fc9
Unverified
Commit
3b945fc9
authored
Dec 23, 2024
by
M.Emin Ozturk
Committed by
GitHub
Dec 23, 2024
Browse files
Merge branch 'develop' into gemm_bf16_sk_muozturk
parents
6ef3acec
3d15f364
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1451 additions
and
237 deletions
+1451
-237
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+1
-0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+24
-18
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+48
-37
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+0
-2
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+34
-19
example/ck_tile/13_moe_sorting/script/smoke_test.sh
example/ck_tile/13_moe_sorting/script/smoke_test.sh
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...e/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+34
-19
include/ck/config.h.in
include/ck/config.h.in
+2
-2
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+2
-2
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+1
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+2
-0
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+4
-2
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+37
-19
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+6
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+65
-18
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...lock_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+126
-47
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
...ock_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
+794
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
...litkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
+226
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+29
-7
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+14
-41
No files found.
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
View file @
3b945fc9
...
@@ -119,6 +119,7 @@ PIPELINE_MAP = {
...
@@ -119,6 +119,7 @@ PIPELINE_MAP = {
PIPELINE_ENUM_MAP
=
{
PIPELINE_ENUM_MAP
=
{
"qr"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS"
,
"qr"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS"
,
"qr_async"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC"
,
"qr_async"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC"
,
"qr_nwarp_sshuffle"
:
"ck_tile::BlockFmhaPipelineEnum::QRKSVS"
,
}
}
BOOL_MAP
=
{
BOOL_MAP
=
{
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
3b945fc9
...
@@ -44,13 +44,12 @@ FMHA_FWD_KERNEL_BODY="""
...
@@ -44,13 +44,12 @@ FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile_{F_idx}
,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>
,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile_{F_idx}
,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>
,
{F_vlayout}>;
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
...
@@ -306,15 +305,19 @@ class FmhaFwdTileSize:
...
@@ -306,15 +305,19 @@ class FmhaFwdTileSize:
F_rm1
:
int
# number of warps for gemm1 along q seqlen
F_rm1
:
int
# number of warps for gemm1 along q seqlen
F_rn1
:
int
# number of warps for gemm1 along head dim v
F_rn1
:
int
# number of warps for gemm1 along head dim v
F_rk1
:
int
# number of warps for gemm1 along k seqlen (not used)
F_rk1
:
int
# number of warps for gemm1 along k seqlen (not used)
F_wm
:
int
# warp size along m (warp size)
F_wm0
:
int
# gemm0 warp size along m
F_wn
:
int
# warp size along n
F_wn0
:
int
# gemm0 warp size along n
F_wk
:
int
# warp size along k
F_wk0
:
int
# gemm0 warp size along k
F_wm1
:
int
# gemm1 warp size along m
F_wn1
:
int
# gemm1 warp size along n
F_wk1
:
int
# gemm1 warp size along k
F_occupancy
:
int
# occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_occupancy
:
int
# occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@
property
@
property
def
name
(
self
)
->
str
:
def
name
(
self
)
->
str
:
return
f
"b
{
self
.
F_bm0
}
x
{
self
.
F_bn0
}
x
{
self
.
F_bk0
}
x
{
self
.
F_bn1
}
x
{
self
.
F_bk1
}
x
{
self
.
F_bk0max
}
"
+
\
return
f
"b
{
self
.
F_bm0
}
x
{
self
.
F_bn0
}
x
{
self
.
F_bk0
}
x
{
self
.
F_bn1
}
x
{
self
.
F_bk1
}
x
{
self
.
F_bk0max
}
"
+
\
f
"_r
{
self
.
F_rm0
}
x
{
self
.
F_rn0
}
x
{
self
.
F_rk0
}
_r
{
self
.
F_rm1
}
x
{
self
.
F_rn1
}
x
{
self
.
F_rk1
}
"
+
\
f
"_r
{
self
.
F_rm0
}
x
{
self
.
F_rn0
}
x
{
self
.
F_rk0
}
_r
{
self
.
F_rm1
}
x
{
self
.
F_rn1
}
x
{
self
.
F_rk1
}
"
+
\
f
"_w
{
self
.
F_wm
}
x
{
self
.
F_wn
}
x
{
self
.
F_wk
}
"
+
(
""
if
self
.
F_occupancy
==
-
1
else
f
"_o
{
self
.
F_occupancy
}
"
)
f
"_w
{
self
.
F_wm0
}
x
{
self
.
F_wn0
}
x
{
self
.
F_wk0
}
_w
{
self
.
F_wm1
}
x
{
self
.
F_wn1
}
x
{
self
.
F_wk1
}
"
+
\
(
""
if
self
.
F_occupancy
==
-
1
else
f
"_o
{
self
.
F_occupancy
}
"
)
@
dataclass
@
dataclass
class
FmhaFwdKernel
:
class
FmhaFwdKernel
:
...
@@ -352,9 +355,12 @@ class FmhaFwdKernel:
...
@@ -352,9 +355,12 @@ class FmhaFwdKernel:
F_rm1
=
self
.
F_tile
.
F_rm1
,
F_rm1
=
self
.
F_tile
.
F_rm1
,
F_rn1
=
self
.
F_tile
.
F_rn1
,
F_rn1
=
self
.
F_tile
.
F_rn1
,
F_rk1
=
self
.
F_tile
.
F_rk1
,
F_rk1
=
self
.
F_tile
.
F_rk1
,
F_wm
=
self
.
F_tile
.
F_wm
,
F_wm0
=
self
.
F_tile
.
F_wm0
,
F_wn
=
self
.
F_tile
.
F_wn
,
F_wn0
=
self
.
F_tile
.
F_wn0
,
F_wk
=
self
.
F_tile
.
F_wk
,
F_wk0
=
self
.
F_tile
.
F_wk0
,
F_wm1
=
self
.
F_tile
.
F_wm1
,
F_wn1
=
self
.
F_tile
.
F_wn1
,
F_wk1
=
self
.
F_tile
.
F_wk1
,
F_vlayout
=
LAYOUT_MAP
[
self
.
F_pipeline
.
F_vlayout
],
F_vlayout
=
LAYOUT_MAP
[
self
.
F_pipeline
.
F_vlayout
],
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_skpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_skpad
],
F_skpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_skpad
],
...
@@ -409,17 +415,17 @@ class FmhaFwdKernel:
...
@@ -409,17 +415,17 @@ class FmhaFwdKernel:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
FmhaFwdTileSize
(
128
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'32'
:
FmhaFwdTileSize
(
128
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
-
1
),
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
#
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32,
96, 4, 1, 1, 4, 1, 1, 32, 32, 16,
32, 32, 16,
-1),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
-
1
),
}
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
return
{
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
)
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
)
,
}
}
else
:
else
:
return
None
return
None
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
3b945fc9
...
@@ -39,6 +39,7 @@ K0_MAX_SUBMAX_MAP = {
...
@@ -39,6 +39,7 @@ K0_MAX_SUBMAX_MAP = {
FMHA_FWD_SPLITKV_PIPELINE_MAP
=
{
FMHA_FWD_SPLITKV_PIPELINE_MAP
=
{
"qr"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS"
,
"qr"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS"
,
"qr_nwarp_sshuffle"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS"
,
"qr_async"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync"
,
"qr_async"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync"
,
}
}
...
@@ -50,13 +51,12 @@ namespace {{
...
@@ -50,13 +51,12 @@ namespace {{
template <bool kHasUnevenSplits>
template <bool kHasUnevenSplits>
struct kernel_runner {{
struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile
,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>
,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile
,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>
,
{F_vlayout}>;
{F_vlayout}>;
using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
...
@@ -161,9 +161,8 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
...
@@ -161,9 +161,8 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
{F_hdim},
{F_hdim},
{F_bm0},
{F_bn1},
{F_mode},
{F_mode},
{F_bn1},
fmha_trait>;
fmha_trait>;
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
...
@@ -177,9 +176,11 @@ using fmha_epilogue =
...
@@ -177,9 +176,11 @@ using fmha_epilogue =
false, false>>;
false, false>>;
using fmha_kernel =
using fmha_kernel =
ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>,
ck_tile::FmhaFwdSplitKVCombineKernel<
fmha_pipeline,
ck_tile::FmhaFwdSplitKVCombineTilePartitioner<
fmha_epilogue>;
fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>,
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
{{
...
@@ -192,7 +193,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
...
@@ -192,7 +193,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}};
}};
}}
}}
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode},
{F_bm0},
{F_bn1},
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
{F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
{F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
#include <iostream>
#include <iostream>
...
@@ -250,16 +251,25 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
...
@@ -250,16 +251,25 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
// make sure we can reuse the padding flags in combine kernels
static_assert({F_bm0} % kM0 == 0);
static_assert({F_bn1} % 32 == 0);
if (t.has_lse) {{
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
return -1;
return -1;
}} else {{
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode},
{F_bm0}/2, {
F_bn1
}/
2, true, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode},
/*
F_bn1
=*/3
2, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}}
}} else {{
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode},
{F_bm0}/2, {
F_bn1
}/
2, false, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode},
/*
F_bn1
=*/3
2, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}}
...
@@ -302,7 +312,7 @@ class FmhaFwdSplitKVApiTrait:
...
@@ -302,7 +312,7 @@ class FmhaFwdSplitKVApiTrait:
if
self
.
pipeline_tag
==
'qr_async'
:
if
self
.
pipeline_tag
==
'qr_async'
:
if
self
.
spad
==
't'
:
return
'true'
# always support
if
self
.
spad
==
't'
:
return
'true'
# always support
else
:
return
'true'
else
:
return
'true'
elif
self
.
pipeline_tag
in
[
'qr'
]:
elif
self
.
pipeline_tag
in
[
'qr'
,
'qr_nwarp_sshuffle'
]:
if
self
.
spad
==
't'
:
return
f
'true /*a.seqlen_q %
{
self
.
bm0
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
if
self
.
spad
==
't'
:
return
f
'true /*a.seqlen_q %
{
self
.
bm0
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.seqlen_q %
{
self
.
bm0
}
== 0'
else
:
return
f
'a.seqlen_q %
{
self
.
bm0
}
== 0'
else
:
assert
False
else
:
assert
False
...
@@ -313,7 +323,7 @@ class FmhaFwdSplitKVApiTrait:
...
@@ -313,7 +323,7 @@ class FmhaFwdSplitKVApiTrait:
if
self
.
pipeline_tag
==
'qr_async'
:
if
self
.
pipeline_tag
==
'qr_async'
:
if
self
.
skpad
==
't'
:
return
f
'a.seqlen_k == 0 || a.seqlen_k %
{
self
.
bn0
}
!= 0'
if
self
.
skpad
==
't'
:
return
f
'a.seqlen_k == 0 || a.seqlen_k %
{
self
.
bn0
}
!= 0'
else
:
return
f
'a.seqlen_k != 0 && a.seqlen_k %
{
self
.
bn0
}
== 0'
else
:
return
f
'a.seqlen_k != 0 && a.seqlen_k %
{
self
.
bn0
}
== 0'
elif
self
.
pipeline_tag
in
[
'qr'
,
'qr_
fp8
'
]:
elif
self
.
pipeline_tag
in
[
'qr'
,
'qr_
nwarp_sshuffle
'
]:
if
self
.
skpad
==
't'
:
return
f
'true /*a.seqlen_k %
{
self
.
bn0
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
if
self
.
skpad
==
't'
:
return
f
'true /*a.seqlen_k %
{
self
.
bn0
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.seqlen_k %
{
self
.
bn0
}
== 0'
else
:
return
f
'a.seqlen_k %
{
self
.
bn0
}
== 0'
else
:
assert
False
else
:
assert
False
...
@@ -324,7 +334,7 @@ class FmhaFwdSplitKVApiTrait:
...
@@ -324,7 +334,7 @@ class FmhaFwdSplitKVApiTrait:
vec
=
int
((
32
*
4
)
/
DTYPE_BITS
[
self
.
dtype
])
vec
=
int
((
32
*
4
)
/
DTYPE_BITS
[
self
.
dtype
])
if
self
.
dpad
==
't'
:
return
f
'a.hdim_q %
{
vec
}
== 0'
if
self
.
dpad
==
't'
:
return
f
'a.hdim_q %
{
vec
}
== 0'
else
:
assert
False
else
:
assert
False
elif
self
.
pipeline_tag
in
[
'qr'
]:
elif
self
.
pipeline_tag
in
[
'qr'
,
'qr_nwarp_sshuffle'
]:
bk0submax
=
K0_MAX_SUBMAX_MAP
[
self
.
bk0max
]
bk0submax
=
K0_MAX_SUBMAX_MAP
[
self
.
bk0max
]
if
self
.
dpad
==
't'
:
return
f
'true /*a.hdim_q %
{
bk0submax
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
if
self
.
dpad
==
't'
:
return
f
'true /*a.hdim_q %
{
bk0submax
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.hdim_q %
{
bk0submax
}
== 0'
else
:
return
f
'a.hdim_q %
{
bk0submax
}
== 0'
...
@@ -336,7 +346,7 @@ class FmhaFwdSplitKVApiTrait:
...
@@ -336,7 +346,7 @@ class FmhaFwdSplitKVApiTrait:
vec
=
int
((
32
*
4
)
/
DTYPE_BITS
[
self
.
dtype
])
vec
=
int
((
32
*
4
)
/
DTYPE_BITS
[
self
.
dtype
])
if
self
.
dvpad
==
't'
:
return
f
'a.hdim_v %
{
vec
}
== 0'
if
self
.
dvpad
==
't'
:
return
f
'a.hdim_v %
{
vec
}
== 0'
else
:
assert
False
else
:
assert
False
elif
self
.
pipeline_tag
in
[
'qr'
]:
elif
self
.
pipeline_tag
in
[
'qr'
,
'qr_nwarp_sshuffle'
]:
bk0submax
=
K0_MAX_SUBMAX_MAP
[
self
.
bk0max
]
bk0submax
=
K0_MAX_SUBMAX_MAP
[
self
.
bk0max
]
if
self
.
dvpad
==
't'
:
return
f
'true /*a.hdim_v %
{
bk0submax
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
if
self
.
dvpad
==
't'
:
return
f
'true /*a.hdim_v %
{
bk0submax
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.hdim_v %
{
bk0submax
}
== 0'
else
:
return
f
'a.hdim_v %
{
bk0submax
}
== 0'
...
@@ -447,12 +457,11 @@ class FmhaFwdSplitKVApiPool:
...
@@ -447,12 +457,11 @@ class FmhaFwdSplitKVApiPool:
@
dataclass
@
dataclass
class
FmhaFwdSplitKVCombineTileSize
:
class
FmhaFwdSplitKVCombineTileSize
:
F_bm0
:
int
# tile size along q seqlen
F_bn1
:
int
# tile size along v head_dim
F_bn1
:
int
# tile size along v head_dim
F_occupancy
:
int
# occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_occupancy
:
int
# occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@
property
@
property
def
name
(
self
)
->
str
:
def
name
(
self
)
->
str
:
return
f
"b
{
self
.
F_
bm0
}
x
{
self
.
F_
bn1
}
"
+
\
return
f
"b
{
self
.
F_bn1
}
"
+
\
(
""
if
self
.
F_occupancy
==
-
1
else
f
"_o
{
self
.
F_occupancy
}
"
)
(
""
if
self
.
F_occupancy
==
-
1
else
f
"_o
{
self
.
F_occupancy
}
"
)
@
dataclass
@
dataclass
...
@@ -485,9 +494,12 @@ class FmhaFwdSplitKVKernel:
...
@@ -485,9 +494,12 @@ class FmhaFwdSplitKVKernel:
F_rm1
=
self
.
F_tile
.
F_rm1
,
F_rm1
=
self
.
F_tile
.
F_rm1
,
F_rn1
=
self
.
F_tile
.
F_rn1
,
F_rn1
=
self
.
F_tile
.
F_rn1
,
F_rk1
=
self
.
F_tile
.
F_rk1
,
F_rk1
=
self
.
F_tile
.
F_rk1
,
F_wm
=
self
.
F_tile
.
F_wm
,
F_wm0
=
self
.
F_tile
.
F_wm0
,
F_wn
=
self
.
F_tile
.
F_wn
,
F_wn0
=
self
.
F_tile
.
F_wn0
,
F_wk
=
self
.
F_tile
.
F_wk
,
F_wk0
=
self
.
F_tile
.
F_wk0
,
F_wm1
=
self
.
F_tile
.
F_wm1
,
F_wn1
=
self
.
F_tile
.
F_wn1
,
F_wk1
=
self
.
F_tile
.
F_wk1
,
F_vlayout
=
LAYOUT_MAP
[
self
.
F_pipeline
.
F_vlayout
],
F_vlayout
=
LAYOUT_MAP
[
self
.
F_pipeline
.
F_vlayout
],
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_skpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_skpad
],
F_skpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_skpad
],
...
@@ -553,7 +565,6 @@ class FmhaFwdSplitKVCombineKernel:
...
@@ -553,7 +565,6 @@ class FmhaFwdSplitKVCombineKernel:
F_idx
=
self
.
F_idx
,
F_idx
=
self
.
F_idx
,
F_hdim
=
self
.
F_hdim
,
F_hdim
=
self
.
F_hdim
,
F_dtype
=
FWD_DTYPE_MAP
[
self
.
F_dtype
],
F_dtype
=
FWD_DTYPE_MAP
[
self
.
F_dtype
],
F_bm0
=
self
.
F_tile
.
F_bm0
,
F_bn1
=
self
.
F_tile
.
F_bn1
,
F_bn1
=
self
.
F_tile
.
F_bn1
,
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
...
@@ -577,17 +588,17 @@ class FmhaFwdSplitKVCombineKernel:
...
@@ -577,17 +588,17 @@ class FmhaFwdSplitKVCombineKernel:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
FmhaFwdTileSize
(
32
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'32'
:
FmhaFwdTileSize
(
32
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
16
,
16
,
16
,
16
,
16
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
64
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
64
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
16
,
16
,
16
,
-
1
),
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
#
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16,
16, 16, 16,
-1),
'128'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'128'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
16
,
16
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
16
,
16
,
16
,
-
1
),
}
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
return
{
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
)
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
32
,
32
,
32
,
-
1
)
,
}
}
else
:
else
:
return
None
return
None
...
@@ -595,17 +606,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
...
@@ -595,17 +606,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
def
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
FmhaFwdSplitKVCombineTileSize
(
16
,
16
,
-
1
),
'32'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
'64'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
32
,
-
1
),
'64'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
## '96' : FmhaFwdSplitKVCombineTileSize(32,
64,
-1),
#
## '96'
: FmhaFwdSplitKVCombineTileSize(32, -1),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
64
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
128
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
}
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
return
{
'64'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
),
'64'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
}
}
else
:
else
:
return
None
return
None
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
3b945fc9
...
@@ -709,7 +709,6 @@ std::string fmha_fwd_splitkv_get_name_();
...
@@ -709,7 +709,6 @@ std::string fmha_fwd_splitkv_get_name_();
template
<
ck_tile
::
index_t
HDim_
,
template
<
ck_tile
::
index_t
HDim_
,
typename
DataType_
,
typename
DataType_
,
bool
kIsGroupMode_
,
bool
kIsGroupMode_
,
ck_tile
::
index_t
kM0_
,
ck_tile
::
index_t
kN1_
,
ck_tile
::
index_t
kN1_
,
bool
kStoreLse_
,
bool
kStoreLse_
,
bool
kDoFp8StaticQuant_
,
bool
kDoFp8StaticQuant_
,
...
@@ -720,7 +719,6 @@ struct fmha_fwd_splitkv_combine_traits_
...
@@ -720,7 +719,6 @@ struct fmha_fwd_splitkv_combine_traits_
static
constexpr
ck_tile
::
index_t
HDim
=
HDim_
;
static
constexpr
ck_tile
::
index_t
HDim
=
HDim_
;
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
static
constexpr
bool
kStoreLse
=
kStoreLse_
;
static
constexpr
bool
kStoreLse
=
kStoreLse_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
View file @
3b945fc9
...
@@ -3,18 +3,42 @@
...
@@ -3,18 +3,42 @@
#include "moe_sorting_api.hpp"
#include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
using ms_problem = \
auto kargs = kernel::MakeKargs(a); \
ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
const dim3 grids = kernel::GridSize(a); \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
const dim3 blocks = kernel::BlockSize(a); \
auto kargs = kernel::MakeKargs(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
const dim3 grids = kernel::GridSize(a); \
float ave_time = ck_tile::launch_kernel( \
const dim3 blocks = kernel::BlockSize(a); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
{
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
...
@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
...
@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
case
(
6
):
{
case
(
6
):
{
MOE_SORTING_DISPATCH
(
6
);
MOE_SORTING_DISPATCH
(
6
);
}
}
case
(
7
):
{
MOE_SORTING_DISPATCH
(
7
);
}
case
(
8
):
{
case
(
8
):
{
MOE_SORTING_DISPATCH
(
8
);
MOE_SORTING_DISPATCH
(
8
);
}
}
case
(
9
):
{
MOE_SORTING_DISPATCH
(
9
);
}
case
(
10
):
{
case
(
10
):
{
MOE_SORTING_DISPATCH
(
10
);
MOE_SORTING_DISPATCH
(
10
);
}
}
case
(
11
):
{
MOE_SORTING_DISPATCH
(
11
);
}
default:
{
default:
{
MOE_SORTING_DISPATCH
(
4
);
MOE_SORTING_DISPATCH
(
4
);
}
}
...
...
example/ck_tile/13_moe_sorting/script/smoke_test.sh
View file @
3b945fc9
...
@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19
...
@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19
$EXE
-t
=
71
-e
=
11
-k
=
11
$EXE
-t
=
71
-e
=
11
-k
=
11
$EXE
-t
=
1
-e
=
1
-k
=
1
$EXE
-t
=
1
-e
=
1
-k
=
1
$EXE
-t
=
99
-e
=
2
-k
=
1
$EXE
-t
=
99
-e
=
2
-k
=
1
$EXE
-t
=
333
-e
=
99
-k
=
13
$EXE
-t
=
333
-e
=
99
-k
=
13
\ No newline at end of file
$EXE
-t
=
128
-e
=
32
-k
=
5
-moe_buf_size
=
262144
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
View file @
3b945fc9
...
@@ -3,18 +3,42 @@
...
@@ -3,18 +3,42 @@
#include "fused_moesorting.hpp"
#include "fused_moesorting.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
using ms_problem = \
auto kargs = kernel::MakeKargs(a); \
ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
const dim3 grids = kernel::GridSize(a); \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
const dim3 blocks = kernel::BlockSize(a); \
auto kargs = kernel::MakeKargs(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
const dim3 grids = kernel::GridSize(a); \
float ave_time = ck_tile::launch_kernel( \
const dim3 blocks = kernel::BlockSize(a); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float
fused_moesorting
(
fused_moesorting_trait
t
,
fused_moesorting_args
a
,
ck_tile
::
stream_config
s
)
float
fused_moesorting
(
fused_moesorting_trait
t
,
fused_moesorting_args
a
,
ck_tile
::
stream_config
s
)
{
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
...
@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
...
@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
case
(
6
):
{
case
(
6
):
{
MOE_SORTING_DISPATCH
(
6
);
MOE_SORTING_DISPATCH
(
6
);
}
}
case
(
7
):
{
MOE_SORTING_DISPATCH
(
7
);
}
case
(
8
):
{
case
(
8
):
{
MOE_SORTING_DISPATCH
(
8
);
MOE_SORTING_DISPATCH
(
8
);
}
}
case
(
9
):
{
MOE_SORTING_DISPATCH
(
9
);
}
case
(
10
):
{
case
(
10
):
{
MOE_SORTING_DISPATCH
(
10
);
MOE_SORTING_DISPATCH
(
10
);
}
}
case
(
11
):
{
MOE_SORTING_DISPATCH
(
11
);
}
default:
{
default:
{
MOE_SORTING_DISPATCH
(
4
);
MOE_SORTING_DISPATCH
(
4
);
}
}
...
...
include/ck/config.h.in
View file @
3b945fc9
...
@@ -115,8 +115,8 @@
...
@@ -115,8 +115,8 @@
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#endif
#endif
#ifndef
D
CK_USE_OCP_FP8
#ifndef CK_USE_OCP_FP8
#cmakedefine
D
CK_USE_OCP_FP8 @
D
CK_USE_OCP_FP8@
#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
#endif
#endif
#ifndef CK_USE_FNUZ_FP8
#ifndef CK_USE_FNUZ_FP8
...
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
3b945fc9
...
@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
...
@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert
(
static_assert
(
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
3b945fc9
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
remove_cvref_t
<
decltype
(
StaticTileDistribution
{}.
get_ys_to_d_descriptor
())
>
;
remove_cvref_t
<
decltype
(
StaticTileDistribution
{}.
get_ys_to_d_descriptor
())
>
;
static
constexpr
index_t
kThreadElementSpaceSize
=
ThreadTensorDesc
{}.
get_element_space_size
();
static
constexpr
index_t
kThreadElementSpaceSize
=
ThreadTensorDesc
{}.
get_element_space_size
();
static_assert
(
0
<
kThreadElementSpaceSize
,
"Make sure tile distribution is valid"
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_dimension
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_dimension
()
{
{
...
...
include/ck_tile/ops/fmha.hpp
View file @
3b945fc9
...
@@ -29,6 +29,8 @@
...
@@ -29,6 +29,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
3b945fc9
...
@@ -71,7 +71,8 @@ struct FmhaFwdKernel
...
@@ -71,7 +71,8 @@ struct FmhaFwdKernel
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
g0br
=
typename
bfs
::
Gemm0BlockWarps
;
using
g0br
=
typename
bfs
::
Gemm0BlockWarps
;
using
g1br
=
typename
bfs
::
Gemm1BlockWarps
;
using
g1br
=
typename
bfs
::
Gemm1BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
using
g0wt
=
typename
bfs
::
Gemm0WarpTile
;
using
g1wt
=
typename
bfs
::
Gemm1WarpTile
;
#define _SS_ std::string
#define _SS_ std::string
#define _TS_ std::to_string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
auto
pn
=
[
&
]
()
{
...
@@ -88,7 +89,8 @@ struct FmhaFwdKernel
...
@@ -88,7 +89,8 @@ struct FmhaFwdKernel
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
3b945fc9
...
@@ -8,9 +8,11 @@ namespace ck_tile {
...
@@ -8,9 +8,11 @@ namespace ck_tile {
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdSplitKVCombineKernel
struct
FmhaFwdSplitKVCombineKernel
{
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
remove_cvref_t
<
FmhaPipeline_
>
;
using
FmhaPipeline
=
remove_cvref_t
<
FmhaPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
static
constexpr
index_t
kNumWarps
=
FmhaPipeline
::
kNumWarps
;
static
constexpr
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
static
constexpr
index_t
kBlockPerCu
=
FmhaPipeline
::
kBlockPerCu
;
static
constexpr
index_t
kBlockPerCu
=
FmhaPipeline
::
kBlockPerCu
;
static_assert
(
kBlockPerCu
>
0
);
static_assert
(
kBlockPerCu
>
0
);
...
@@ -50,8 +52,7 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -50,8 +52,7 @@ struct FmhaFwdSplitKVCombineKernel
return
return
_SS_
(
"fmha_fwd_splitkv_combine_d"
)
+
_TS_
(
FmhaPipeline
::
kHeadDimV
)
+
"_"
+
_SS_
(
t2s
<
ODataType
>::
name
)
+
_SS_
(
"fmha_fwd_splitkv_combine_d"
)
+
_TS_
(
FmhaPipeline
::
kHeadDimV
)
+
"_"
+
_SS_
(
t2s
<
ODataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
"b"
+
_TS_
(
FmhaPipeline
::
kM0
)
+
"x"
+
"b"
+
_TS_
(
FmhaPipeline
::
kN1
)
+
"_"
+
_TS_
(
FmhaPipeline
::
kN1
)
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
...
@@ -339,37 +340,56 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -339,37 +340,56 @@ struct FmhaFwdSplitKVCombineKernel
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
number
<
1
>
{});
// read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
const
auto
o_acc_dram_view
=
pad_tensor_view
(
const
auto
o_acc_dram_view
=
pad_tensor_view
(
o_acc_dram_naive
,
o_acc_dram_naive
,
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
make_tuple
(
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
number
<
kNumWarps
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
true
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
const
index_t
padded_num_splits
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
0
>
{}];
const
index_t
padded_seqlen_q
=
const
index_t
padded_seqlen_q
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
const
index_t
padded_hdim_v
=
const
index_t
padded_hdim_v
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
return
transform_tensor_view
(
const
index_t
num_m_tiles
=
integer_divide_floor
(
padded_seqlen_q
,
FmhaPipeline
::
kM0
);
// transform tensor view by following steps, given shape: (padded_num_splits,
// padded_seqlen_q, padded_hdim_v)
// 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
// 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
// 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
auto
transposed
=
transform_tensor_view
(
o_acc_dram_view
,
o_acc_dram_view
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_seqlen_q
)),
make_tuple
(
make_pass_through_transform
(
padded_num_splits
),
make_unmerge_transform
(
make_tuple
(
num_m_tiles
,
FmhaPipeline
::
kM0
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{},
sequence
<
3
>
{}));
return
transform_tensor_view
(
transposed
,
make_tuple
(
make_merge_transform
(
make_tuple
(
num_m_tiles
,
padded_num_splits
,
FmhaPipeline
::
kM0
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
,
2
>
{},
sequence
<
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}();
}();
auto
lse_acc_dram_window
=
make_tile_window
(
auto
lse_acc_dram_window
=
make_tile_window
(
lse_acc_dram
,
lse_acc_dram
,
[
&
]()
{
make_tuple
(
number
<
FmhaPipeline
::
kMaxSplits
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
return
make_tuple
(
number
<
FmhaPipeline
::
kMaxSplits
>
{},
number
<
FmhaPipeline
::
kM0
>
{});
}(),
{
0
,
i_m0
});
{
0
,
i_m0
});
const
index_t
padded_num_splits
=
integer_divide_ceil
(
kargs
.
num_splits
,
kNumWarps
)
*
kNumWarps
;
auto
o_acc_dram_window
=
make_tile_window
(
auto
o_acc_dram_window
=
make_tile_window
(
o_acc_dram
,
o_acc_dram
,
[
&
]()
{
make_tuple
(
number
<
kNumWarps
*
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{});
{
i_tile_m
*
padded_num_splits
*
FmhaPipeline
::
kM0
,
i_n1
});
}(),
{
i_m0
,
i_n1
});
// LSE DRAM window
// LSE DRAM window
auto
lse_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
auto
lse_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
...
@@ -410,7 +430,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -410,7 +430,6 @@ struct FmhaFwdSplitKVCombineKernel
identity
{},
// lse_element_func
identity
{},
// lse_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
kargs
.
num_splits
,
kargs
.
num_splits
,
kargs
.
seqlen_q
,
smem_ptr
);
smem_ptr
);
}
}
else
else
...
@@ -419,7 +438,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -419,7 +438,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window
,
o_acc_dram_window
,
lse_dram_window
,
lse_dram_window
,
kargs
.
num_splits
,
kargs
.
num_splits
,
kargs
.
seqlen_q
,
smem_ptr
);
smem_ptr
);
}
}
}();
}();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
3b945fc9
...
@@ -45,6 +45,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -45,6 +45,7 @@ struct FmhaFwdSplitKVKernel
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
FmhaPipeline
::
kStoreLSE
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
static
constexpr
bool
kIsPagedKV
=
FmhaPipeline
::
Problem
::
kIsPagedKV
;
static
constexpr
bool
kIsPagedKV
=
FmhaPipeline
::
Problem
::
kIsPagedKV
;
...
@@ -67,7 +68,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -67,7 +68,8 @@ struct FmhaFwdSplitKVKernel
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
g0br
=
typename
bfs
::
Gemm0BlockWarps
;
using
g0br
=
typename
bfs
::
Gemm0BlockWarps
;
using
g1br
=
typename
bfs
::
Gemm1BlockWarps
;
using
g1br
=
typename
bfs
::
Gemm1BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
using
g0wt
=
typename
bfs
::
Gemm0WarpTile
;
using
g1wt
=
typename
bfs
::
Gemm1WarpTile
;
#define _SS_ std::string
#define _SS_ std::string
#define _TS_ std::to_string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
auto
pn
=
[
&
]
()
{
...
@@ -84,11 +86,12 @@ struct FmhaFwdSplitKVKernel
...
@@ -84,11 +86,12 @@ struct FmhaFwdSplitKVKernel
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
)
+
(
kIsPagedKV
?
"_pagedkv"
:
""
);
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
)
+
(
kIsPagedKV
?
"_pagedkv"
:
""
);
#undef _SS_
#undef _SS_
#undef _TS_
#undef _TS_
// clang-format on
// clang-format on
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
3b945fc9
...
@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
index_t
kNumWarps
=
Problem
::
kNumWarps
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kHeadDimV
=
Problem
::
kHeadDimV
;
static
constexpr
index_t
kHeadDimV
=
Problem
::
kHeadDimV
;
...
@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
LSEElementFunction
&
lse_element_func
,
const
LSEElementFunction
&
lse_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
num_splits
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
void
*
smem_ptr
)
const
{
{
// lse_acc tile in LDS
// lse_acc tile in LDS
...
@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto
lse_acc_tile
=
load_tile
(
lse_acc_dram_window
);
auto
lse_acc_tile
=
load_tile
(
lse_acc_dram_window
);
store_tile
(
lse_acc_lds_write_window
,
lse_acc_tile
);
store_tile
(
lse_acc_lds_write_window
,
lse_acc_tile
);
block_sync_lds
();
auto
lse_accum
=
make_static_distributed_tensor
<
LSEDataType
>
(
auto
lse_accum
=
make_static_distributed_tensor
<
LSEDataType
>
(
Policy
::
template
MakeLSEaccRegTileDistribution
<
Problem
>());
Policy
::
template
MakeLSEaccRegTileDistribution
<
Problem
>());
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
// and fill up -INF values outside the [kM0, num_splits] region.
{
{
...
@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
}
}
});
});
}
}
block_sync_lds
();
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse_logsum
));
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse_logsum
));
}
}
auto
o_acc_dist
=
Policy
::
template
MakeOaccDramTileDistribution
<
Problem
>();
auto
o_acc_
4_
dist
=
Policy
::
template
MakeOacc
4
DramTileDistribution
<
Problem
>();
auto
o_acc_dram_window
=
auto
o_acc_
4_
dram_window
=
make_tile_window
(
o_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
o_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
o_acc_dram_block_window_tmp
.
get_window_lengths
(),
o_acc_dram_block_window_tmp
.
get_window_lengths
(),
o_acc_dram_block_window_tmp
.
get_window_origin
(),
o_acc_dram_block_window_tmp
.
get_window_origin
(),
o_acc_dist
);
o_acc_4_dist
);
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
const
index_t
padded_seqlen_q
=
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
kM0
;
// shape=[4 * KM0, kN1]
auto
o_acc_4
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_4_dist
);
clear_tile
(
o_acc_4
);
for
(
index_t
i_split
=
0
;
i_split
<
num_splits
;
++
i_split
)
const
index_t
padded_num_splits
=
integer_divide_ceil
(
num_splits
,
kNumWarps
)
*
kNumWarps
;
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
// each warp handles a [KM0, kN1] tile
for
(
index_t
split_start
=
0
;
split_start
<
padded_num_splits
;
split_start
+=
kNumWarps
)
{
{
auto
o_tile
=
load_tile
(
o_acc_dram_window
);
auto
o_tile
=
load_tile
(
o_acc_4_dram_window
);
const
index_t
i_split
=
split_start
+
get_warp_id
();
const
index_t
row_start
=
kM0
*
get_warp_id
();
{
{
constexpr
auto
spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
constexpr
auto
spans
=
decltype
(
o_acc
_4
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
o_acc
.
get_tile_distribution
(),
i_j_idx
);
o_acc
_4
.
get_tile_distribution
(),
i_j_idx
);
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
const
LSEDataType
lse_scale
=
lse_acc_lds
(
row
,
i_split
);
const
LSEDataType
lse_scale
=
lse_acc_lds
(
row
-
row_start
,
i_split
);
o_acc
(
i_j_idx
)
+=
lse_scale
*
o_tile
(
i_j_idx
);
o_acc
_4
(
i_j_idx
)
+=
lse_scale
*
o_tile
(
i_j_idx
);
});
});
});
});
}
}
move_tile_window
(
o_acc_dram_window
,
{
padded_seqlen_q
,
0
});
move_tile_window
(
o_acc_4_dram_window
,
{
kNumWarps
*
kM0
,
0
});
}
// 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
OaccDataType
*
o_acc_4_lds_ptr
=
static_cast
<
OaccDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeLSEacc
<
Problem
>()));
{
auto
o_acc_4_lds_window
=
[
&
]()
{
auto
desc
=
Policy
::
template
MakeOacc4LdsBlockDescriptor
<
Problem
>();
auto
view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
o_acc_4_lds_ptr
,
desc
);
return
make_tile_window
(
view
,
desc
.
get_lengths
(),
{
0
,
0
});
}();
store_tile
(
o_acc_4_lds_window
,
o_acc_4
);
}
}
auto
o_acc_dist
=
Policy
::
template
MakeOaccDramTileDistribution
<
Problem
>();
auto
o_acc_4_lds_window
=
[
&
]()
{
auto
desc
=
Policy
::
template
MakeOacc4LdsBlockDescriptor
<
Problem
>();
auto
view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
o_acc_4_lds_ptr
,
desc
);
return
make_tile_window
(
view
,
desc
.
get_lengths
(),
{
0
,
0
},
o_acc_dist
);
}();
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
static_for
<
0
,
kNumWarps
,
1
>
{}([
&
](
auto
)
{
auto
o_acc_in
=
load_tile
(
o_acc_4_lds_window
);
{
constexpr
auto
spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
+=
o_acc_in
(
i_j_idx
);
});
});
}
move_tile_window
(
o_acc_4_lds_window
,
{
kM0
,
0
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
return
o_acc
;
...
@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
index_t
num_splits
,
index_t
num_splits
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
void
*
smem_ptr
)
const
{
{
return
operator
()(
lse_acc_dram_block_window
,
return
operator
()(
lse_acc_dram_block_window
,
...
@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity
{},
identity
{},
identity
{},
identity
{},
num_splits
,
num_splits
,
seqlen_q
,
smem_ptr
);
smem_ptr
);
}
}
};
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
View file @
3b945fc9
...
@@ -10,23 +10,38 @@ namespace ck_tile {
...
@@ -10,23 +10,38 @@ namespace ck_tile {
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
{
template
<
index_t
BlockSize
,
index_t
M
,
index_t
N
,
typename
DataType
>
template
<
index_t
NumWarps
,
index_t
M
,
index_t
N
,
typename
DataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMaxNumWarpsForTile
()
{
static_assert
(
NumWarps
==
1
||
NumWarps
==
2
||
NumWarps
==
4
);
constexpr
index_t
ElemPerThread
=
(
M
*
N
)
/
(
NumWarps
*
get_warp_size
());
if
constexpr
(
0
<
ElemPerThread
)
{
return
NumWarps
;
}
else
{
// try dividing tile by smaller # of warps
return
GetMaxNumWarpsForTile
<
NumWarps
/
2
,
M
,
N
,
DataType
>
();
}
}
template
<
index_t
NumWarps
,
index_t
M
,
index_t
N
,
typename
DataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeForTile
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeForTile
()
{
{
constexpr
index_t
PixelsPerThread
=
(
M
*
N
)
/
BlockSize
;
constexpr
index_t
MaxNumWarps
=
GetMaxNumWarpsForTile
<
NumWarps
,
M
,
N
,
DataType
>
();
static_assert
(
0
<
PixelsPerThread
);
constexpr
index_t
MaxNPerThread
=
16
/
sizeof
(
DataType
);
constexpr
index_t
ElemPerThread
=
(
M
*
N
)
/
(
MaxNumWarps
*
get_warp_size
());
constexpr
index_t
NPerThread
=
min
(
MaxNPerThread
,
PixelsPerThread
);
return
NPerThread
;
constexpr
index_t
MaxNPerThread
=
16
/
sizeof
(
DataType
);
return
min
(
MaxNPerThread
,
ElemPerThread
);
}
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
{
{
return
GetVectorSizeForTile
<
Problem
::
k
BlockSize
,
return
GetVectorSizeForTile
<
Problem
::
k
NumWarps
,
Problem
::
kMaxSplits
,
Problem
::
kMaxSplits
,
Problem
::
kM0
,
Problem
::
kM0
,
typename
Problem
::
LSEDataType
>
();
typename
Problem
::
LSEDataType
>
();
...
@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
LSEacc
()
{
{
return
sizeof
(
typename
Problem
::
LSEDataType
)
*
return
sizeof
(
typename
Problem
::
LSEDataType
)
*
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeOacc4
()
{
return
sizeof
(
typename
Problem
::
OaccDataType
)
*
MakeOacc4LdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
GetSmemSizeLSEacc
<
Problem
>
()
+
GetSmemSizeOacc4
<
Problem
>
();
}
// shape=[kMaxSplits, kM0]
// shape=[kMaxSplits, kM0]
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
{
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNumWarps
=
Problem
::
kNumWarps
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
MaxNumWarps
=
GetMaxNumWarpsForTile
<
Problem
::
kNumWarps
,
kNPerBlock
,
kMPerBlock
,
LSEDataType
>
();
constexpr
index_t
Replicate
=
Problem
::
kNumWarps
/
MaxNumWarps
;
constexpr
index_t
NPerThread
=
constexpr
index_t
NPerThread
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
GetVectorSizeForTile
<
MaxNumWarps
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
k
NumWarps
*
MThreadsPerWarp
);
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
Max
NumWarps
*
MThreadsPerWarp
);
static_assert
(
MPerThread
*
MaxNumWarps
*
MThreadsPerWarp
==
kMPerBlock
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MPerThread
*
kNumWarps
*
MThreadsPerWarp
==
kMPerBlock
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
Replicate
>
,
tuple
<
sequence
<
MPerThread
,
k
NumWarps
,
MThreadsPerWarp
>
,
tuple
<
sequence
<
MPerThread
,
Max
NumWarps
,
MThreadsPerWarp
>
,
sequence
<
NThreads
,
NPerThread
>>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
0
,
1
>>
{});
}
}
...
@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
GetVectorSizeForTile
<
Problem
::
kNumWarps
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
NPack
>
{},
number
<
1
>
{});
number
<
1
>
{});
constexpr
auto
lse_acc_lds_block_desc
=
transform_tensor_descriptor
(
constexpr
auto
lse_acc_lds_block_desc
=
transform_tensor_descriptor
(
...
@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
GetVectorSizeForTile
<
Problem
::
kNumWarps
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
NPack
>
{},
number
<
1
>
{});
number
<
1
>
{});
constexpr
auto
lse_acc_t_lds_block_desc
=
transform_tensor_descriptor
(
constexpr
auto
lse_acc_t_lds_block_desc
=
transform_tensor_descriptor
(
...
@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return
lse_acc_t_lds_block_desc
;
return
lse_acc_t_lds_block_desc
;
}
}
// 3d + padding, shape=[4 * kM0, kN1]
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
LSEaccRegTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
Oacc4LdsBlockDescriptor
()
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
4
*
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
Problem
::
kNumWarps
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
o_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
o_acc_t_lds_block_desc
=
transform_tensor_descriptor
(
o_acc_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kMPerBlock
),
make_merge_transform
(
make_tuple
(
kNPerBlock
/
NPack
,
NPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
o_acc_t_lds_block_desc
;
}
// shape=[kM0, kMaxSplits]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccRegTileDistribution
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NThreads
=
4
;
constexpr
index_t
MaxNThreads
=
8
;
constexpr
index_t
NPerThread
=
kNPerBlock
/
NThreads
;
constexpr
index_t
NThreads
=
min
(
kNPerBlock
,
MaxNThreads
);
constexpr
index_t
NPerThread
=
kNPerBlock
/
NThreads
;
constexpr
index_t
MThreads
=
kBlockSize
/
NThreads
;
constexpr
index_t
MPerThread
=
1
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
MThreads
;
constexpr
index_t
MThreads
=
kMPerBlock
/
MPerThread
;
constexpr
index_t
MWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
MaxNumWarps
=
(
MThreads
*
NThreads
)
/
get_warp_size
();
constexpr
index_t
Replicate
=
Problem
::
kNumWarps
/
MaxNumWarps
;
static_assert
(
MaxNumWarps
*
MThreadPerWarp
*
MPerThread
==
kMPerBlock
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MWarps
*
MThreadPerWarp
*
MPerThread
==
kMPerBlock
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<
Replicate
>
,
sequence
<
1
>
,
tuple
<
sequence
<
MaxNumWarps
,
MThreadPerWarp
,
MPerThread
>
,
tuple
<
sequence
<
MWarps
,
MThreadPerWarp
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
sequence
<
2
,
1
>>
{});
}
// similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
// direction
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOacc4DramTileDistribution
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
// real kMPerBlock we want is (4 * kM0)
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
static_assert
(
get_warp_size
()
<=
kMPerBlock
*
kNPerBlock
);
constexpr
index_t
M1
=
1
;
// compose encoding base on 1 warp
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
kNPerBlock
/
N0
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
4
,
M0
,
M1
,
M2
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
2
>
,
sequence
<
3
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
static_assert
(
kBlockSize
<=
kMPerBlock
*
kNPerBlock
);
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
0 → 100644
View file @
3b945fc9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
>
struct
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kSubQKHeaddim
=
BlockFmhaShape
::
kSubQKHeaddim
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kIsPagedKV
=
Problem
::
kIsPagedKV
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentOacc
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOacc
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kQKHeaddim
<=
32
)
{
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<=
64
)
{
return
3
;
}
else
if
constexpr
(
kQKHeaddim
<=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qr_nwarp_sshuffle"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowLengths
,
typename
KPageBlockNavigator
,
typename
VDramBlockWindowLengths
,
typename
VPageBlockNavigator
,
typename
BiasDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEaccElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowLengths
&
k_dram_block_window_lengths
,
// N0*K0 tile
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
KElementFunction
&
k_element_func
,
const
VDramBlockWindowLengths
&
v_dram_block_window_lengths
,
// N1*K1 tile
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
LSEaccDramBlockWindowTmp
&
lse_acc_dram_window_tmp
,
// M0*1 tile
const
LSEaccElementFunction
&
lse_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
index_t
kv_l2p_offset
,
// logical-to-physical offset of seqlen_k coordinate
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KPageBlockNavigator
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VPageBlockNavigator
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kSubQKHeaddim
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN0
==
KDramBlockWindowLengths
{}[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowLengths
{}[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowLengths
{}[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowLengths
{}[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// Q tile in LDS
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
max
(
Policy
::
template
GetSmemSizeQ
<
Problem
>(),
Policy
::
template
GetSmemSizeK
<
Problem
>())),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// S tile in LDS
auto
s_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
SaccDataType
*>
(
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
max
(
Policy
::
template
GetSmemSizeQ
<
Problem
>(),
Policy
::
template
GetSmemSizeK
<
Problem
>())),
Policy
::
template
MakeSLdsBlockDescriptor
<
Problem
>());
auto
s_write_lds_window
=
make_tile_window
(
s_lds
,
Policy
::
template
MakeSLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
auto
s_read_lds_window
=
make_tile_window
(
s_lds
,
Policy
::
template
MakeSLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
},
Policy
::
template
MakeSRegTileDistribution
<
Problem
>());
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
// load Q here, will store Q into LDS to maximize throughput
auto
origin_q
=
load_tile
(
q_dram_window
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
auto
o_acc
=
OaccBlockTileType
{};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
o_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
// init M, L
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
logical_seqlen_k_start
,
logical_seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
const
index_t
logical_num_total_loop
=
integer_divide_ceil
(
logical_seqlen_k_end
-
logical_seqlen_k_start
,
kN0
);
if
(
logical_num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
auto
lse_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
if
(
get_thread_local_1d_id
()
<
kM0
)
{
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
}
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return
o_acc
;
}
}
const
index_t
physical_seqlen_k_start
=
logical_seqlen_k_start
+
kv_l2p_offset
;
const
index_t
physical_seqlen_k_end
=
logical_seqlen_k_end
+
kv_l2p_offset
;
// make sure the first tile is completely located in page-block (page-block size should be
// divisible by kN0)
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
const
index_t
aligned_physical_seqlen_k_start
=
[
&
,
physical_seqlen_k_start_
=
physical_seqlen_k_start
]
{
if
constexpr
(
kIsPagedKV
)
{
return
kN0
*
integer_divide_floor
(
physical_seqlen_k_start_
,
kN0
);
}
else
{
return
physical_seqlen_k_start_
;
}
}();
const
index_t
num_total_loop
=
integer_divide_ceil
(
physical_seqlen_k_end
-
aligned_physical_seqlen_k_start
,
kN0
);
auto
[
i_page_block_k
,
k_dram_block_window
]
=
k_page_block_navigator
.
make_tile_window
(
k_dram_block_window_lengths
,
{
aligned_physical_seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
logical_seqlen_k_start
-
(
physical_seqlen_k_start
-
aligned_physical_seqlen_k_start
)},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
auto
[
i_page_block_v
,
v_dram_window
]
=
v_page_block_navigator
.
make_tile_window
(
v_dram_block_window_lengths
,
{
0
,
aligned_physical_seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// store Q into LDS
__builtin_amdgcn_sched_barrier
(
0
);
auto
q_lds_window_for_store
=
make_tile_window
(
q_lds
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
store_tile
(
q_lds_window_for_store
,
origin_q
);
__builtin_amdgcn_sched_barrier
(
0
);
// load Q from LDS
__builtin_amdgcn_sched_barrier
(
0
);
auto
q_lds_window_for_load
=
make_tile_window
(
q_lds
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
},
Policy
::
template
MakeQRegTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
block_sync_lds
();
auto
q
=
load_tile
(
q_lds_window_for_load
);
__builtin_amdgcn_sched_barrier
(
0
);
auto
q_tile
=
tile_elementwise_in
(
q_element_func
,
q
);
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
,
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load the first tile of the first iteration and store to LDS
auto
k_block_tile
=
load_tile
(
k_dram_window
);
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
do
{
// STAGE 1, QK gemm
clear_tile
(
s_acc
);
// initialize C
// load the second tile of the first iteration
k_block_tile
=
load_tile
(
k_dram_window
);
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
// LDS write i + 1
k_block_tile
=
load_tile
(
k_dram_window
);
// global read i + 2
});
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
sequence
<
kM0
,
(
k0_loops
-
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
k_lds_window
);
}
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
+=
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#else
x
+=
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
*=
scale_s
;
// position_encoding accept only logical coordinates, do conversion here
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
-
kv_l2p_offset
);
});
});
}
else
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
/// TODO: only check in first/last iteration without increasing code size
if
constexpr
(
kHasUnevenSplits
)
{
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
physical_seqlen_k_start_
=
physical_seqlen_k_start
,
physical_seqlen_k_end_
=
physical_seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
if
constexpr
(
kIsPagedKV
)
{
return
col
<
physical_seqlen_k_start_
||
physical_seqlen_k_end_
<=
col
;
}
else
{
return
physical_seqlen_k_end_
<=
col
;
}
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
// mask accept only logical coordinates, do conversion here
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{})
-
kv_l2p_offset
,
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
-
kv_l2p_offset
);
});
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// load the first tile for next iteration
if
(
i_total_loops
<
num_total_loop
-
1
)
{
// move K tile windows
i_page_block_k
=
k_page_block_navigator
.
move_tile_window
(
i_page_block_k
,
k_dram_block_window
,
{
kN0
,
0
});
k_dram_window
=
make_tile_window
(
k_dram_block_window
,
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window
// laod the first tile of the first iteration and store to LDS
k_block_tile
=
load_tile
(
k_dram_window
);
}
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
// shuffle through LDS so that the tile layout is consistent with required by Gemm1
store_tile
(
s_write_lds_window
,
s
);
block_sync_lds
();
auto
s_new
=
load_tile
(
s_read_lds_window
);
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s_new
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s_new
.
get_tile_distribution
());
// Pcompute{j}
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s_new
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s_new
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s_new
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_prefetch
);
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
}
i_page_block_v
=
v_page_block_navigator
.
move_tile_window
(
i_page_block_v
,
v_dram_window
,
{
0
,
kK1
});
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
,
&
i_page_block_v_
=
i_page_block_v
,
&
v_dram_window_
=
v_dram_window
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window_
);
// load next v
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
v_lds_window
);
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v
);
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
}
i_page_block_v_
=
v_page_block_navigator
.
move_tile_window
(
i_page_block_v_
,
v_dram_window_
,
{
0
,
kK1
});
});
}
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
k1_loops
*
kK1
>
{}),
v_lds_window
);
block_sync_lds
();
}
__builtin_amdgcn_sched_barrier
(
0
);
// load the first tile for next iteration
if
(
i_total_loops
<
num_total_loop
-
1
)
{
// store the first tile for next iteration to LDS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
}
}
while
(
++
i_total_loops
<
num_total_loop
);
if
constexpr
(
kStoreLSE
)
{
// store lse acc
auto
lse_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_acc_spans
=
decltype
(
lse_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_acc_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
/
C_LOG2E
+
log
(
l_
[
i_idx
]);
}
else
{
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
/
C_LOG2E
+
log
(
l_
[
i_idx
]);
}
#else
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
});
if
(
get_thread_local_1d_id
()
<
kM0
)
{
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
}
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowLengths
,
typename
KPageBlockNavigator
,
typename
VDramBlockWindowLengths
,
typename
VPageBlockNavigator
,
typename
BiasDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowLengths
&
k_dram_block_window_lengths
,
// N0*K0 tile
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
VDramBlockWindowLengths
&
v_dram_block_window_lengths
,
// N1*K1 tile
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
LSEaccDramBlockWindowTmp
&
lse_acc_dram_block_window_tmp
,
// M0*1 tile
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
index_t
kv_l2p_offset
,
// logical-to-physical offset of seqlen_k coordinate
void
*
smem_ptr
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_lengths
,
k_page_block_navigator
,
identity
{},
v_dram_block_window_lengths
,
v_page_block_navigator
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
lse_acc_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
num_splits
,
i_split
,
mask
,
position_encoding
,
scale_s
,
kv_l2p_offset
,
smem_ptr
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
0 → 100644
View file @
3b945fc9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
struct
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
{
using
BasePolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
QDataType
);
// this should align with MakeQDramTileDistribution()
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
return
min
(
ElemPerThread
,
MaxVectorSize
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentOacc
()
{
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
return
static_cast
<
index_t
>
(
16
/
sizeof
(
OaccDataType
));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
QDataType
);
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
constexpr
index_t
kMaxVecLoad
=
min
(
ElemPerThread
,
MaxVectorSize
);
constexpr
index_t
KPerThread
=
kMaxVecLoad
;
constexpr
index_t
KThreads
=
kKPerBlock
/
KPerThread
;
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
KThreads
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
MThreadPerWarp
*
NumWarps
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
NumWarps
,
MThreadPerWarp
>
,
sequence
<
KThreads
,
KPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQRegTileDistribution
()
{
return
BasePolicy
::
template
MakeQDramTileDistribution
<
Problem
,
BlockGemm
>();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
{
// TODO: this is for 3d layout
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
static_cast
<
index_t
>
(
16
/
sizeof
(
QDataType
));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
constexpr
index_t
kKPack
=
min
(
ElemPerThread
,
GetSmemKPackQ
<
Problem
>
());
constexpr
auto
q_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kMPerBlock
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
q_lds_block_desc
=
transform_tensor_descriptor
(
q_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
kMPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
q_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemNPackS
()
{
using
SDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
return
static_cast
<
index_t
>
(
16
/
sizeof
(
SDataType
));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPack
=
GetSmemNPackS
<
Problem
>
();
constexpr
auto
s_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
kNPack
>
{},
number
<
kMPerBlock
>
{},
number
<
kNPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
kNPack
>
{},
number
<
kNPack
>
{},
number
<
1
>
{}),
number
<
kNPack
>
{},
number
<
1
>
{});
constexpr
auto
s_lds_block_desc
=
transform_tensor_descriptor
(
s_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
kMPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kNPerBlock
/
kNPack
>
{},
number
<
kNPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
s_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSRegTileDistribution
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetKVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static_assert
(
MWarp
==
1
,
"Check failed!"
);
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kTileK
=
Problem
::
BlockFmhaShape
::
kN0
;
// K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
constexpr
index_t
K3
=
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K2
=
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K1
=
kKPerBlock
/
(
K2
*
K3
);
constexpr
index_t
K0
=
kTileK
/
kKPerBlock
;
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
constexpr
auto
s2_block_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
2
,
2
>>
,
sequence
<
1
,
2
,
2
,
2
>
,
sequence
<
0
,
0
,
1
,
3
>>
{};
constexpr
auto
s2_block_dstr
=
make_static_tile_distribution
(
s2_block_dstr_encoding
);
return
s2_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
return
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
()
*
sizeof
(
typename
Problem
::
QDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeK
()
{
return
MakeKLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
()
*
sizeof
(
typename
Problem
::
KDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeV
()
{
return
MakeVLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
()
*
sizeof
(
typename
Problem
::
VDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeS
()
{
return
MakeSLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
()
*
sizeof
(
typename
Problem
::
SaccDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
max
(
GetSmemSizeQ
<
Problem
>
(),
GetSmemSizeK
<
Problem
>
())
+
max
(
GetSmemSizeV
<
Problem
>
(),
GetSmemSizeS
<
Problem
>
());
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
View file @
3b945fc9
...
@@ -106,28 +106,43 @@ struct BlockFmhaFwdSplitKVPipelineProblem
...
@@ -106,28 +106,43 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
};
// extract tile size attributes to remove dependency on traits
template
<
typename
OaccDataType_
,
ck_tile
::
index_t
kN1_
>
struct
BlockFmhaSplitKVCombinePipelineTileSizes
{
static
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
OaccDataType_
);
static
constexpr
index_t
kN1
=
kN1_
;
static
constexpr
index_t
NThreads
=
kN1
/
MaxVectorSize
;
static
constexpr
index_t
kM0
=
get_warp_size
()
/
NThreads
;
// MThreadPerWarp
};
template
<
typename
LSEDataType_
,
template
<
typename
LSEDataType_
,
typename
OaccDataType_
,
typename
OaccDataType_
,
typename
ODataType_
,
typename
ODataType_
,
index_t
HeadDimV_
,
index_t
HeadDimV_
,
index_t
kM0_
,
index_t
kN1_
,
bool
kIsGroupMode_
,
bool
kIsGroupMode_
,
ck_tile
::
index_t
kN1_
,
typename
Traits_
>
typename
Traits_
>
struct
BlockFmhaSplitKVCombinePipelineProblem
struct
BlockFmhaSplitKVCombinePipelineProblem
:
BlockFmhaSplitKVCombinePipelineTileSizes
<
OaccDataType_
,
kN1_
>
{
{
using
BaseType
=
BlockFmhaSplitKVCombinePipelineTileSizes
<
OaccDataType_
,
kN1_
>
;
using
LSEDataType
=
remove_cvref_t
<
LSEDataType_
>
;
using
LSEDataType
=
remove_cvref_t
<
LSEDataType_
>
;
using
OaccDataType
=
remove_cvref_t
<
OaccDataType_
>
;
using
OaccDataType
=
remove_cvref_t
<
OaccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kNumWarps
=
kM0_
/
(
get_warp_size
()
/
4
);
static_assert
(
std
::
is_same_v
<
LSEDataType
,
OaccDataType
>
);
static
constexpr
index_t
kBlockSize
=
kNumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
index_t
kHeadDimV
=
HeadDimV_
;
static
constexpr
index_t
kHeadDimV
=
HeadDimV_
;
static
constexpr
index_t
kM0
=
kM0_
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
index_t
kN1
=
kN1_
;
using
BaseType
::
kM0
;
using
BaseType
::
kN1
;
static_assert
(
kN1
<=
kHeadDimV
&&
kHeadDimV
%
kN1
==
0
);
// attributes from traits
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
...
@@ -136,6 +151,13 @@ struct BlockFmhaSplitKVCombinePipelineProblem
...
@@ -136,6 +151,13 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kMaxSplits
=
Traits
::
kMaxSplits
;
static
constexpr
index_t
kMaxSplits
=
Traits
::
kMaxSplits
;
static_assert
(
8
<=
kMaxSplits
);
static
constexpr
index_t
kNumWarps
=
4
;
// always use 4 warps for each workgroup
static
constexpr
index_t
kBlockSize
=
kNumWarps
*
get_warp_size
();
static_assert
(
get_warp_size
()
<=
(
kM0
*
kMaxSplits
)
&&
(
kM0
*
kMaxSplits
)
%
get_warp_size
()
==
0
);
};
};
template
<
typename
QDataType_
,
template
<
typename
QDataType_
,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
3b945fc9
...
@@ -41,52 +41,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -41,52 +41,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
{
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
QDataType
);
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
return
min
(
MaxVectorSize
,
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
);
}
}
template
<
typename
Problem
,
typename
BlockGemm
>
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
return
BlockGemm
::
template
MakeABlockTileDistribution
<
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
Problem
::
BlockFmhaShape
::
kM0
,
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
>();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
K2
=
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K1
=
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K0
=
kKPerBlock
/
(
K1
*
K2
);
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
if
constexpr
(
1
<
Problem
::
kNumGemm0Warps
)
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
else
{
static_assert
(
MWarp
==
1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -105,7 +74,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -105,7 +74,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static_assert
(
WarpGemmM
==
16
||
WarpGemmM
==
32
);
static_assert
(
WarpGemmM
==
4
||
WarpGemmM
==
16
||
WarpGemmM
==
32
);
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
...
@@ -113,8 +82,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -113,8 +82,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
{
if
constexpr
(
WarpGemmM
==
32
)
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
//
WarpGemmM == 16
else
if
constexpr
(
WarpGemmM
==
16
)
return
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
{};
return
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
{};
else
// WarpGemmM == 4
return
WarpGemmMfmaF16F16F32M4N64K16
{};
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
...
@@ -122,8 +93,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -122,8 +93,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
{
if
constexpr
(
WarpGemmM
==
32
)
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
//
WarpGemmM == 16
else
if
constexpr
(
WarpGemmM
==
16
)
return
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
{};
return
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
{};
else
// WarpGemmM == 4
return
WarpGemmMfmaBf16Bf16F32M4N64K16
{};
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment