Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
916daf59
Commit
916daf59
authored
Jan 26, 2025
by
Qianfeng Zhang
Browse files
Use k0_loops small tile load/store to replace the big tile load/store for K
parent
4776c8c0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
177 additions
and
64 deletions
+177
-64
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+5
-23
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+35
-14
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+137
-27
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
916daf59
...
...
@@ -1082,20 +1082,10 @@ struct FmhaFwdKernel
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
sequence
<
false
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}();
const
auto
v_dram
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
...
...
@@ -1147,15 +1137,7 @@ struct FmhaFwdKernel
{
i_m0
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
{
0
,
0
});
k_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
{
0
,
0
});
auto
v_dram_window
=
make_tile_window
(
v_dram
,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
916daf59
...
...
@@ -154,14 +154,18 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kSubQKHeaddim
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
constexpr
auto
NumVLdsBuffers
=
Policy
::
template
GetNumVLdsBuffers
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
...
...
@@ -257,8 +261,18 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
auto
k_tile
=
load_tile
(
k_dram_window
);
using
k_tile_type
=
decltype
(
load_tile
(
k_dram_window
));
statically_indexed_array
<
k_tile_type
,
k0_loops
>
k_tiles
;
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tiles
[
i_k0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
});
move_tile_window
(
k_dram_window
,
{
0
,
-
k0_loops
*
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -301,18 +315,18 @@ struct BlockFmhaPipelineQRKSVSAsync
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
// ensure loading of Q from LDS completely done
block_sync_lds
();
do
{
store_tile
(
k_lds_window
,
k_tile
);
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
i_k0
*
kN0
,
0
>
{},
sequence
<
(
i_k0
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
i_k0
]);
});
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -322,7 +336,14 @@ struct BlockFmhaPipelineQRKSVSAsync
if
(
i_total_loops
<
num_total_loop
-
1
)
{
move_tile_window
(
k_dram_window
,
{
kN0
,
0
});
k_tile
=
load_tile
(
k_dram_window
);
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tiles
[
i_k0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
});
move_tile_window
(
k_dram_window
,
{
0
,
-
k0_loops
*
kK0
});
}
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -335,8 +356,8 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
k
K
0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
sequence
<
i_k0
*
k
N0
,
0
>
{},
sequence
<
(
i_k0
+
1
)
*
kN0
,
kK0
>
{}));
});
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
916daf59
...
...
@@ -291,6 +291,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using
QXPolicy
=
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
;
/// Number of K buffers laid out in LDS.
///
/// When the policy's KLoadOnce flag is set, one buffer is counted per kK0-wide
/// slice of the QK head dimension (kQKHeaddim / kK0); otherwise a single
/// buffer is reported.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers()
{
    if constexpr(!KLoadOnce)
    {
        return 1;
    }
    else
    {
        using Shape = remove_cvref_t<typename Problem::BlockFmhaShape>;

        // One LDS buffer per kK0 slice of the head dimension.
        constexpr index_t num_k0_slices = Shape::kQKHeaddim / Shape::kK0;
        return num_k0_slices;
    }
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
GetNumVLdsBuffers
()
{
...
...
@@ -317,8 +332,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
KLoadOnce
?
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
:
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
KDataType
);
...
...
@@ -382,6 +396,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
;
}
/*
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
...
...
@@ -400,12 +415,36 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
k_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{},
sequence<1>{}));
return k_lds_block_desc;
}
*/
/*
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{
constexpr index_t SingleKSize = [&]() {
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
}();
return SingleKSize;
}
*/
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVSingleSmemElementSpaceSize
()
...
...
@@ -428,6 +467,78 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return
SingleVSize
;
}
/// Builds the LDS block descriptor for K.
///
/// The underlying naive descriptor is 4-d:
///   (NumKLdsBuffers, kK0 / kKPack, kN0, kKPack)
/// where each (kK0 / kKPack) slab is strided by (kN0 + 1) * kKPack elements,
/// i.e. one extra kKPack-sized row of padding per slab (presumably to avoid
/// LDS bank conflicts -- TODO confirm). The returned descriptor is the 2-d view
///   (NumKLdsBuffers * kN0, kK0)
/// obtained by merging dims {0, 2} and dims {1, 3}.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
    /*
    Disabled alternative layout, kept for reference:

    using KDataType = remove_cvref_t<typename Problem::KDataType>;
    constexpr index_t Banks = 32; // TODO: need change based on arch
    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
    constexpr index_t kKPack = GetSmemKPackK<Problem>();
    static_assert(PixelsPerRow % kKPack == 0);
    constexpr index_t NPerRow = PixelsPerRow / kKPack;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
    static_assert(kNPerBlock % NPerRow == 0);
    static_assert(kKPerBlock % kKPack == 0);
    constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
    constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
    make_tuple(number<NumKLdsBuffers>{},
    number<kKPerBlock / kKPack>{},
    number<kNPerBlock / NPerRow>{},
    number<NPerRow>{},
    number<kKPack>{}),
    make_tuple(number<GetKSingleSmemElementSpaceSize<Problem>()>{},
    number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
    number<PixelsPerRow + kKPack>{},
    number<kKPack>{},
    number<1>{}),
    number<kKPack>{},
    number<1>{});
    constexpr auto k_lds_block_desc = transform_tensor_descriptor(
    k_lds_block_desc_0,
    make_tuple(
    make_merge_transform(make_tuple(
    number<NumKLdsBuffers>{}, number<kNPerBlock / NPerRow>{},
    number<NPerRow>{})), make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{},
    number<kKPack>{}))), make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
    make_tuple(sequence<0>{}, sequence<1>{}));
    return k_lds_block_desc;
    */

    constexpr index_t kNumBuffers = GetNumKLdsBuffers<Problem>();
    constexpr index_t kRows       = Problem::BlockFmhaShape::kN0;
    constexpr index_t kCols       = Problem::BlockFmhaShape::kK0;
    constexpr index_t kPack       = GetSmemKPackK<Problem>();

    // Derived extents and strides of the padded 4-d layout.
    constexpr index_t kSlabs        = kCols / kPack;
    constexpr index_t kSlabStride   = (kRows + 1) * kPack; // kRows rows + 1 padding row
    constexpr index_t kBufferStride = kCols * (kRows + 1); // == kSlabs * kSlabStride

    constexpr auto lengths_4d = make_tuple(
        number<kNumBuffers>{}, number<kSlabs>{}, number<kRows>{}, number<kPack>{});

    constexpr auto strides_4d = make_tuple(
        number<kBufferStride>{}, number<kSlabStride>{}, number<kPack>{}, number<1>{});

    // NOTE(review): the trailing number<8>/number<1> arguments are passed through
    // unchanged; they look like vector-length/stride hints for the innermost
    // dimension -- confirm against make_naive_tensor_descriptor.
    constexpr auto padded_desc =
        make_naive_tensor_descriptor(lengths_4d, strides_4d, number<8>{}, number<1>{});

    // Collapse (buffer, row) into dim 0 and (slab, pack) into dim 1.
    constexpr auto merge_rows =
        make_merge_transform(make_tuple(number<kNumBuffers>{}, number<kRows>{}));
    constexpr auto merge_cols =
        make_merge_transform(make_tuple(number<kSlabs>{}, number<kPack>{}));

    return transform_tensor_descriptor(padded_desc,
                                       make_tuple(merge_rows, merge_cols),
                                       make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
                                       make_tuple(sequence<0>{}, sequence<1>{}));
}
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsBlockDescriptor
()
...
...
@@ -532,8 +643,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
KLoadOnce
?
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
:
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
KDataType
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment