Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
21dc4596
Commit
21dc4596
authored
Jan 25, 2025
by
Qianfeng Zhang
Browse files
Remove KLoadOnce and use NuPrefetchK > 1
parent
00fe0752
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
204 additions
and
112 deletions
+204
-112
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+5
-23
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
...litkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
+2
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
+1
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+0
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+86
-39
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
+11
-14
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
.../pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
+1
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+0
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
.../pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
+1
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+97
-30
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
21dc4596
...
@@ -1082,20 +1082,10 @@ struct FmhaFwdKernel
...
@@ -1082,20 +1082,10 @@ struct FmhaFwdKernel
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
number
<
1
>
{});
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
return
pad_tensor_view
(
{
k_dram_naive
,
return
pad_tensor_view
(
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
k_dram_naive
,
sequence
<
false
,
kPadHeadDimQ
>
{});
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
sequence
<
false
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}();
}();
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
...
@@ -1147,15 +1137,7 @@ struct FmhaFwdKernel
...
@@ -1147,15 +1137,7 @@ struct FmhaFwdKernel
{
i_m0
,
0
});
{
i_m0
,
0
});
auto
k_dram_window
=
make_tile_window
(
auto
k_dram_window
=
make_tile_window
(
k_dram
,
k_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
{
0
,
0
});
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
{
0
,
0
});
auto
v_dram_window
=
auto
v_dram_window
=
make_tile_window
(
v_dram
,
make_tile_window
(
v_dram
,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
View file @
21dc4596
...
@@ -14,10 +14,12 @@ namespace ck_tile {
...
@@ -14,10 +14,12 @@ namespace ck_tile {
struct
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
struct
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopy = */
false
,
/* AsyncCopy = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
/* NumPrefetchV = */
1
>
{
{
using
BasePolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
using
BasePolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopy = */
false
,
/* AsyncCopy = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
/* NumPrefetchV = */
1
>
;
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
View file @
21dc4596
...
@@ -12,6 +12,7 @@ namespace ck_tile {
...
@@ -12,6 +12,7 @@ namespace ck_tile {
struct
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
struct
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopy = */
false
,
/* AsyncCopy = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
/* NumPrefetchV = */
1
>
{
{
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @
21dc4596
...
@@ -35,9 +35,6 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -35,9 +35,6 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
false
;
static_assert
(
kKLoadOnce
==
Policy
::
KLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
21dc4596
...
@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
true
;
static_assert
(
kKLoadOnce
==
Policy
::
KLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
@@ -154,16 +151,18 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -154,16 +151,18 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kSubQKHeaddim
==
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
"wrong!"
);
constexpr
auto
NumKLdsBuffers
=
Policy
::
template
GetNumKLdsBuffers
<
Problem
>();
constexpr
auto
NumVLdsBuffers
=
Policy
::
template
GetNumVLdsBuffers
<
Problem
>();
constexpr
auto
NumVLdsBuffers
=
Policy
::
template
GetNumVLdsBuffers
<
Problem
>();
static_assert
(
NumKLdsBuffers
>=
2
,
"At least two LDS buffers needed for K"
);
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
q_dram_block_window_tmp
.
get_window_origin
(),
...
@@ -181,8 +180,8 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -181,8 +180,8 @@ struct BlockFmhaPipelineQRKSVSAsync
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
smem_ptr
);
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
auto
k_lds_window
=
make_tile_window
(
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kSubQKHeaddim
>
{}
),
{
0
,
0
});
k_lds
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>().
get_lengths
(
),
{
0
,
0
});
// V tile in LDS
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
...
@@ -258,7 +257,12 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -258,7 +257,12 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window
.
get_window_origin
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
// load
auto
k_tile
=
load_tile
(
k_dram_window
);
// prefetch two K tiles
auto
k_tile_0
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
auto
k_tile_1
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -299,7 +303,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -299,7 +303,6 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
// prefetch K tile
index_t
i_total_loops
=
0
;
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
...
@@ -310,42 +313,76 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -310,42 +313,76 @@ struct BlockFmhaPipelineQRKSVSAsync
// ensure loading of Q from LDS completely done
// ensure loading of Q from LDS completely done
block_sync_lds
();
block_sync_lds
();
do
__builtin_amdgcn_sched_barrier
(
0
);
{
store_tile
(
k_lds_window
,
k_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
// store first K tile to LDS
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tile_0
);
do
{
// STAGE 1, QK gemm
// STAGE 1, QK gemm
clear_tile
(
s_acc
);
// initialize C
clear_tile
(
s_acc
);
// initialize C
if
(
i_total_loops
<
num_total_loop
-
1
)
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
{
if
constexpr
(
i_k0
>
0
&&
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
kN0
,
0
});
{
k_tile
=
load_tile
(
k_dram_window
);
if
constexpr
(
i_k0
%
2
==
1
)
}
k_tile_0
=
load_tile
(
k_dram_window
);
else
k_tile_1
=
load_tile
(
k_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
};
__builtin_amdgcn_sched_barrier
(
0
);
// ensure K data needed by this gemm iteration completely available on LDS
block_sync_lds
();
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
((
i_k0
+
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
(((
i_k0
+
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
// store K data needed by next gemm iteration to LDS
if
constexpr
(
i_k0
%
2
==
0
)
store_tile
(
k_lds_window_tmp
,
tile_elementwise_in
(
k_element_func
,
k_tile_1
));
else
store_tile
(
k_lds_window_tmp
,
tile_elementwise_in
(
k_element_func
,
k_tile_0
));
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
(
i_k0
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
((
i_k0
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{}));
__builtin_amdgcn_sched_barrier
(
0
);
});
// ensure k is completely updated on LDS
block_sync_lds
();
block_sync_lds
();
// for kQKHeaddim == 96 (kSubQKHeaddim == 128), we need to use k0_loops
gemm_0
(
s_acc
,
if
constexpr
(
kQKHeaddim
==
kSubQKHeaddim
)
get_slice_tile
(
{
q
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
gemm_0
(
s_acc
,
q
,
k_lds_window
);
get_slice_tile
(
k_lds_window
,
}
sequence
<
((
k0_loops
-
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
else
sequence
<
(((
k0_loops
-
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{}));
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
if
(
i_total_loops
<
num_total_loop
-
1
)
{
{
move_tile_window
(
k_dram_window
,
{
kN0
,
-
k0_loops
*
kK0
});
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tile_0
=
load_tile
(
k_dram_window
);
gemm_0
(
s_acc
,
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
get_slice_tile
(
k_tile_1
=
load_tile
(
k_dram_window
);
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
});
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
...
@@ -427,8 +464,10 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -427,8 +464,10 @@ struct BlockFmhaPipelineQRKSVSAsync
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
tile_elementwise_inout
([](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
s
.
get_tile_distribution
());
// Pcompute{j}
...
@@ -641,8 +680,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -641,8 +680,7 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
}
}
}
}
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
// tail
// tail
{
{
block_sync_lds
();
block_sync_lds
();
...
@@ -653,7 +691,16 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -653,7 +691,16 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence
<
((
k1_loops
-
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
((
k1_loops
-
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
k1_loops
-
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{}));
sequence
<
(((
k1_loops
-
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{}));
}
}
}
while
(
++
i_total_loops
<
num_total_loop
);
__builtin_amdgcn_sched_barrier
(
0
);
if
(
i_total_loops
++
<
num_total_loop
)
{
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tile_0
);
}
}
while
(
i_total_loops
<
num_total_loop
);
// store lse
// store lse
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
View file @
21dc4596
...
@@ -11,6 +11,7 @@ namespace ck_tile {
...
@@ -11,6 +11,7 @@ namespace ck_tile {
struct
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
struct
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopy = */
true
,
/* AsyncCopy = */
true
,
/* NumPrefetchK = */
2
,
/* NumPrefetchV = */
2
>
/* NumPrefetchV = */
2
>
{
{
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -60,20 +61,16 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
...
@@ -60,20 +61,16 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
constexpr
index_t
BlockGemmK
=
(
KLoadOnce
&&
Problem
::
BlockFmhaShape
::
kQKHeaddim
==
using
GemmProblem
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
)
BlockGemmProblem
<
typename
Problem
::
QDataType
,
?
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
typename
Problem
::
KDataType
,
:
Problem
::
BlockFmhaShape
::
kK0
;
typename
Problem
::
SaccDataType
,
Problem
::
kNumGemm0Warps
*
get_warp_size
(),
using
GemmProblem
=
BlockGemmProblem
<
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
typename
Problem
::
QDataType
,
Problem
::
BlockFmhaShape
::
kN0
,
typename
Problem
::
KDataType
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
Problem
::
kNumGemm0Warps
*
get_warp_size
(),
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
BlockGemmK
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
View file @
21dc4596
...
@@ -11,6 +11,7 @@ namespace ck_tile {
...
@@ -11,6 +11,7 @@ namespace ck_tile {
using
BlockFmhaPipelineQRKSVSDefaultPolicy
=
using
BlockFmhaPipelineQRKSVSDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopy = */
false
,
/* AsyncCopy = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
/* NumPrefetchV = */
1
>
;
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
21dc4596
...
@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
bool
kQLoadOnce
=
false
;
static
constexpr
bool
kQLoadOnce
=
false
;
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
false
;
static_assert
(
kKLoadOnce
==
Policy
::
KLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
View file @
21dc4596
...
@@ -12,6 +12,7 @@ namespace ck_tile {
...
@@ -12,6 +12,7 @@ namespace ck_tile {
struct
BlockFmhaPipelineQSKSVSDefaultPolicy
struct
BlockFmhaPipelineQSKSVSDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
false
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
false
,
/* AsyncCopy = */
false
,
/* AsyncCopy = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
/* NumPrefetchV = */
1
>
{
{
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
21dc4596
...
@@ -276,21 +276,26 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -276,21 +276,26 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
}
}
};
};
template
<
bool
QLoadOnce_
,
bool
AsyncCopy_
,
index_t
NumPrefetchV_
>
template
<
bool
QLoadOnce_
,
bool
AsyncCopy_
,
index_t
NumPrefetchK_
,
index_t
NumPrefetchV_
>
struct
BlockFmhaPipelineQXKSVSCustomPolicy
:
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
struct
BlockFmhaPipelineQXKSVSCustomPolicy
:
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
{
{
static
constexpr
index_t
NumPrefetchK
=
NumPrefetchK_
;
static
constexpr
index_t
NumPrefetchV
=
NumPrefetchV_
;
static
constexpr
index_t
NumPrefetchV
=
NumPrefetchV_
;
// 1) When Async == true, we preload whole K-tile for next iteration using single LDS buffer,
// and preload V-slice for next unroll using multiple LDS buffers
// 2) When Async == false, we preload K-slice for next unroll using single LDS buffer, and
// preload V-slice for next unroll using single LDS buffer
static
constexpr
bool
AsyncCopy
=
AsyncCopy_
;
static
constexpr
bool
AsyncCopy
=
AsyncCopy_
;
static
constexpr
bool
KLoadOnce
=
AsyncCopy
;
using
QXPolicy
=
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
;
using
QXPolicy
=
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
;
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
GetNumKLdsBuffers
()
{
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
constexpr
index_t
k0_loops
=
BlockFmhaShape
::
kQKHeaddim
/
BlockFmhaShape
::
kK0
;
return
min
(
NumPrefetchK
,
k0_loops
);
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
GetNumVLdsBuffers
()
CK_TILE_DEVICE
static
constexpr
auto
GetNumVLdsBuffers
()
{
{
...
@@ -317,8 +322,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -317,8 +322,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
KLoadOnce
?
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
:
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
KDataType
);
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
KDataType
);
...
@@ -382,29 +386,51 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -382,29 +386,51 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
;
return
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
;
}
}
/*
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{},
number<kKPack>{}))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{},
sequence<1>{}));
return k_lds_block_desc;
}
*/
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKSingleSmemElementSpaceSize
()
{
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
SingleKSize
=
[
&
]()
{
constexpr
index_t
kKPerBlock
=
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
KLoadOnce
?
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
:
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
KDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
>
{},
number
<
kKPack
>
{}),
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
number
<
8
>
{},
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
number
<
1
>
{});
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
return
(
kKPerBlock
/
kKPack
)
*
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
);
k_lds_block_desc_0
,
}();
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
return
SingleKSize
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -428,6 +454,48 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -428,6 +454,48 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return
SingleVSize
;
return
SingleVSize
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsBlockDescriptor
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
KDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
constexpr
index_t
NumKLdsBuffers
=
GetNumKLdsBuffers
<
Problem
>
();
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumKLdsBuffers
>
{},
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
GetKSingleSmemElementSpaceSize
<
Problem
>
()
>
{},
number
<
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
)
>
{},
number
<
PixelsPerRow
+
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumKLdsBuffers
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
0
,
2
,
3
>
{},
sequence
<
1
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
// 3d + padding
// 3d + padding
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsBlockDescriptor
()
...
@@ -532,8 +600,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -532,8 +600,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
KLoadOnce
?
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
:
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
KDataType
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment