Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4776c8c0
"example/vscode:/vscode.git/clone" did not exist on "a4f24233e51854c4b5cb7d75637fa0f235f78f8e"
Commit
4776c8c0
authored
Jan 26, 2025
by
Qianfeng Zhang
Browse files
Use un-rolled gemm for Gemm-0
parent
00fe0752
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
86 deletions
+80
-86
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+8
-17
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
+72
-69
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
4776c8c0
...
@@ -330,23 +330,14 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -330,23 +330,14 @@ struct BlockFmhaPipelineQRKSVSAsync
// ensure k is completely updated on LDS
// ensure k is completely updated on LDS
block_sync_lds
();
block_sync_lds
();
// for kQKHeaddim == 96 (kSubQKHeaddim == 128), we need to use k0_loops
if
constexpr
(
kQKHeaddim
==
kSubQKHeaddim
)
{
gemm_0
(
s_acc
,
q
,
k_lds_window
);
}
else
{
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
gemm_0
(
s_acc
,
gemm_0
(
get_slice_tile
(
s_acc
,
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
});
});
}
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
View file @
4776c8c0
...
@@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
...
@@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
sequence
<
0
,
1
>>
{});
sequence
<
0
,
1
>>
{});
}
}
/*
template <typename Problem>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
{
...
@@ -71,13 +72,14 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
...
@@ -71,13 +72,14 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
typename Problem::SaccDataType,
typename Problem::SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<
TileGemmShape<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
BlockGemmK
>
,
sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
BlockGemmK>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr auto warp_gemm = []() {
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
constexpr index_t WarpGemmM =
static_assert
(
WarpGemmM
==
4
||
WarpGemmM
==
16
||
WarpGemmM
==
32
);
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); static_assert(WarpGemmM == 4 ||
WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
...
@@ -118,14 +120,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
...
@@ -118,14 +120,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::SaccDataType,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
decltype
(
warp_gemm
)
>
;
Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps)
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
}
*/
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment