Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3ee41b40
Commit
3ee41b40
authored
Jan 22, 2025
by
Qianfeng Zhang
Browse files
Re-implement qr_ks_vs_async pipeline by using kLoadOnce
parent
c0b90f13
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
415 additions
and
645 deletions
+415
-645
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+27
-9
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
...ock_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
+5
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
...litkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
+4
-8
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+5
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
+1
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+8
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+184
-246
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
+75
-7
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
.../pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
+1
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+5
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
.../pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
+1
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+99
-350
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
3ee41b40
...
@@ -1064,14 +1064,14 @@ struct FmhaFwdKernel
...
@@ -1064,14 +1064,14 @@ struct FmhaFwdKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
q_dram_naive
,
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}
else
else
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
q_dram_naive
,
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}
}();
}();
const
auto
k_dram
=
[
&
]()
{
const
auto
k_dram
=
[
&
]()
{
...
@@ -1082,10 +1082,20 @@ struct FmhaFwdKernel
...
@@ -1082,10 +1082,20 @@ struct FmhaFwdKernel
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
number
<
1
>
{});
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
sequence
<
false
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
return
pad_tensor_view
(
k_dram_naive
,
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}();
}();
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
...
@@ -1107,7 +1117,7 @@ struct FmhaFwdKernel
...
@@ -1107,7 +1117,7 @@ struct FmhaFwdKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
v_dram_transposed
,
v_dram_transposed
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
sequence
<
kPadHeadDimV
,
false
>
{});
}
}
else
else
{
{
...
@@ -1121,7 +1131,7 @@ struct FmhaFwdKernel
...
@@ -1121,7 +1131,7 @@ struct FmhaFwdKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
v_dram_naive
,
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
sequence
<
false
,
kPadSeqLenK
>
{});
}
}
}();
}();
...
@@ -1137,7 +1147,15 @@ struct FmhaFwdKernel
...
@@ -1137,7 +1147,15 @@ struct FmhaFwdKernel
{
i_m0
,
0
});
{
i_m0
,
0
});
auto
k_dram_window
=
make_tile_window
(
auto
k_dram_window
=
make_tile_window
(
k_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
{
0
,
0
});
k_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
{
0
,
0
});
auto
v_dram_window
=
auto
v_dram_window
=
make_tile_window
(
v_dram
,
make_tile_window
(
v_dram
,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
View file @
3ee41b40
...
@@ -316,11 +316,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
...
@@ -316,11 +316,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
// load Q from LDS
// load Q from LDS
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
auto
q_lds_window_for_load
=
make_tile_window
(
auto
q_lds_window_for_load
=
q_lds
,
make_tile_window
(
q_lds
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
},
{
0
,
0
},
Policy
::
template
MakeQRegTileDistribution
<
Problem
,
decltype
(
gemm_0
)
>());
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
block_sync_lds
();
block_sync_lds
();
auto
q
=
load_tile
(
q_lds_window_for_load
);
auto
q
=
load_tile
(
q_lds_window_for_load
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
View file @
3ee41b40
...
@@ -13,15 +13,11 @@ namespace ck_tile {
...
@@ -13,15 +13,11 @@ namespace ck_tile {
// This pipeline is qkv all located in LDS
// This pipeline is qkv all located in LDS
struct
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
struct
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopy = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
/* NumPrefetchV = */
1
>
{
{
using
BasePolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
using
BasePolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopy = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
/* NumPrefetchV = */
1
>
;
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -76,10 +72,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
...
@@ -76,10 +72,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
sequence
<
0
,
1
>>
{});
sequence
<
0
,
1
>>
{});
}
}
template
<
typename
Problem
,
typename
BlockGemm
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQRegTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQRegTileDistribution
()
{
{
return
BasePolicy
::
template
MakeQDramTileDistribution
<
Problem
,
BlockGemm
>();
return
BasePolicy
::
template
MakeQDramTileDistribution
<
Problem
>();
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
3ee41b40
...
@@ -180,11 +180,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -180,11 +180,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
auto
q_dram_window
=
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)
>());
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
auto
q
=
load_tile
(
q_dram_window
);
auto
q
=
load_tile
(
q_dram_window
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
View file @
3ee41b40
...
@@ -11,9 +11,7 @@ namespace ck_tile {
...
@@ -11,9 +11,7 @@ namespace ck_tile {
// This pipeline is qkv all located in LDS
// This pipeline is qkv all located in LDS
struct
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
struct
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopy = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
/* NumPrefetchV = */
1
>
{
{
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @
3ee41b40
...
@@ -35,6 +35,9 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -35,6 +35,9 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
false
;
static_assert
(
kKLoadOnce
==
Policy
::
KLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
@@ -178,11 +181,11 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -178,11 +181,11 @@ struct BlockFmhaPipelineQRKSVS
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
auto
q_dram_window
=
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)
>());
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
auto
q
=
load_tile
(
q_dram_window
);
auto
q
=
load_tile
(
q_dram_window
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
3ee41b40
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
View file @
3ee41b40
...
@@ -8,12 +8,80 @@
...
@@ -8,12 +8,80 @@
namespace
ck_tile
{
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
struct
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
using
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
=
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopy = */
true
,
/* AsyncCopyK = */
true
,
/* NumPrefetchV = */
2
>
/* AsyncCopyV = */
false
,
{
/* NumPrefetchK = */
3
,
template
<
typename
Problem
>
/* NumPrefetchV = */
3
>
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
constexpr
index_t
BlockGemmK
=
(
KLoadOnce
&&
Problem
::
BlockFmhaShape
::
kQKHeaddim
==
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
)
?
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
:
Problem
::
BlockFmhaShape
::
kK0
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kNumGemm0Warps
*
get_warp_size
(),
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
BlockGemmK
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static_assert
(
WarpGemmM
==
4
||
WarpGemmM
==
16
||
WarpGemmM
==
32
);
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
if
constexpr
(
WarpGemmM
==
16
)
return
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
{};
else
// WarpGemmM == 4
return
WarpGemmMfmaF16F16F32M4N64K16
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
if
constexpr
(
WarpGemmM
==
16
)
return
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
{};
else
// WarpGemmM == 4
return
WarpGemmMfmaBf16Bf16F32M4N64K16
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
static_assert
(
WarpGemmM
==
32
);
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr
index_t
swizzle_factor
=
4
;
return
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
<
swizzle_factor
>
{};
}
// TODO - bf8_t
}();
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV2CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
if
constexpr
(
1
<
Problem
::
kNumGemm0Warps
)
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
else
return
BlockGemmARegBSmemCRegOneWarpV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
View file @
3ee41b40
...
@@ -8,12 +8,9 @@
...
@@ -8,12 +8,9 @@
namespace
ck_tile
{
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaPipelineQRKSVSDefaultPolicy
=
using
BlockFmhaPipelineQRKSVSDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopy = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
/* NumPrefetchV = */
1
>
;
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
3ee41b40
...
@@ -34,6 +34,9 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -34,6 +34,9 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
bool
kQLoadOnce
=
false
;
static
constexpr
bool
kQLoadOnce
=
false
;
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
false
;
static_assert
(
kKLoadOnce
==
Policy
::
KLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
@@ -94,6 +97,8 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -94,6 +97,8 @@ struct BlockFmhaPipelineQSKSVS
{
{
return
1
;
return
1
;
}
}
else
return
1
;
}
}
}();
}();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
View file @
3ee41b40
...
@@ -11,9 +11,7 @@ namespace ck_tile {
...
@@ -11,9 +11,7 @@ namespace ck_tile {
// This pipeline is qkv all located in LDS
// This pipeline is qkv all located in LDS
struct
BlockFmhaPipelineQSKSVSDefaultPolicy
struct
BlockFmhaPipelineQSKSVSDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
false
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
false
,
/* AsyncCopyK = */
false
,
/* AsyncCopy = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
/* NumPrefetchV = */
1
>
{
{
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
3ee41b40
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment