gaoqiong / composable_kernel_ROCM

Commit 3ee41b40, authored Jan 22, 2025 by Qianfeng Zhang
Parent: c0b90f13

    Re-implement qr_ks_vs_async pipeline by using kLoadOnce
Showing 12 changed files with 415 additions and 645 deletions.
  +27   -9    include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
  +5    -5    include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
  +4    -8    include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
  +5    -5    include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
  +1    -3    include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
  +8    -5    include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
  +184  -246  include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
  +75   -7    include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
  +1    -4    include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
  +5    -0    include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
  +1    -3    include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
  +99   -350  include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
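The common theme of the changes below: the old async pipeline streamed K into LDS in kK0-wide slices through a set of rotating prefetch buffers (the LdsBufferSequence machinery in the custom policy), while the re-implemented pipeline loads the whole kN0 x kSubQKHeaddim K tile once per main-loop iteration ("kLoadOnce") and keeps only V multi-buffered. The following is an illustrative standalone sketch of that contrast, not the committed ck_tile code; the tile sizes are example values:

```cpp
// Illustrative sketch only (plain C++, not ck_tile): the two K staging
// schemes this commit swaps. Tile sizes below are example values.
#include <cstdio>

int main()
{
    constexpr int kN0 = 128, kSubQKHeaddim = 128, kK0 = 32;
    constexpr int k0_loops = kSubQKHeaddim / kK0;

    // before: K streamed as k0_loops slices of kN0 x kK0, each async-copied
    // into one of several rotating LDS buffers, with a fence plus barrier
    // between a slice's arrival and the gemm that consumes it
    int slice_stores = k0_loops; // 4 stores, 4 fence/barrier points

    // after (kLoadOnce): the whole kN0 x kSubQKHeaddim tile is loaded to
    // registers with an ordinary buffer_load, stored to LDS once, and gemm_0
    // slices it out of LDS while the next K tile load is in flight
    int whole_tile_stores = 1;

    std::printf("K stores per main-loop iteration: %d before, %d after\n",
                slice_stores, whole_tile_stores);
}
```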
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp (view file @ 3ee41b40)

@@ -1064,14 +1064,14 @@ struct FmhaFwdKernel
                 return pad_tensor_view(
                     q_dram_naive,
                     make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                    sequence<false, kPadHeadDimQ>{});
             }
             else
             {
                 return pad_tensor_view(
                     q_dram_naive,
                     make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                    sequence<false, kPadHeadDimQ>{});
             }
         }();
         const auto k_dram = [&]() {
@@ -1082,10 +1082,20 @@ struct FmhaFwdKernel
                 number<FmhaPipeline::kAlignmentK>{},
                 number<1>{});
+            if constexpr(FmhaPipeline::kKLoadOnce)
+            {
+                return pad_tensor_view(
+                    k_dram_naive,
+                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
+                    sequence<false, kPadHeadDimQ>{});
+            }
+            else
+            {
                 return pad_tensor_view(
                     k_dram_naive,
                     make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
+                    sequence<false, kPadHeadDimQ>{});
+            }
         }();
         const auto v_dram = [&]() {
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1107,7 +1117,7 @@ struct FmhaFwdKernel
                 return pad_tensor_view(
                     v_dram_transposed,
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenK>{});
+                    sequence<kPadHeadDimV, false>{});
             }
             else
             {
@@ -1121,7 +1131,7 @@ struct FmhaFwdKernel
                 return pad_tensor_view(
                     v_dram_naive,
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenK>{});
+                    sequence<false, kPadSeqLenK>{});
             }
         }();
@@ -1137,7 +1147,15 @@ struct FmhaFwdKernel
             {i_m0, 0});
         auto k_dram_window = make_tile_window(
-            k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
+            k_dram,
+            [&]() {
+                if constexpr(FmhaPipeline::kKLoadOnce)
+                    return make_tuple(number<FmhaPipeline::kN0>{},
+                                      number<FmhaPipeline::kSubQKHeaddim>{});
+                else
+                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
+            }(),
+            {0, 0});
         auto v_dram_window =
             make_tile_window(v_dram,
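The kernel now sizes the K DRAM window from the pipeline's kKLoadOnce trait. A minimal sketch of the same compile-time dispatch idiom (an immediately-invoked constexpr lambda), using standard C++ stand-ins rather than ck_tile's number<>/make_tuple:

```cpp
// Minimal sketch (assumed simplification of the kernel change above): the K
// window's column count is picked at compile time by the kKLoadOnce trait.
#include <tuple>

template <bool kKLoadOnce, int kN0, int kK0, int kSubQKHeaddim>
constexpr auto make_k_window_lengths()
{
    return []() {
        if constexpr(kKLoadOnce)
            return std::make_tuple(kN0, kSubQKHeaddim); // whole head dim at once
        else
            return std::make_tuple(kN0, kK0); // one kK0-wide slice at a time
    }();
}

static_assert(std::get<1>(make_k_window_lengths<true, 128, 32, 128>()) == 128);
static_assert(std::get<1>(make_k_window_lengths<false, 128, 32, 128>()) == 32);

int main() {}
```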
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp (view file @ 3ee41b40)

@@ -316,11 +316,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
         // load Q from LDS
         __builtin_amdgcn_sched_barrier(0);
         auto q_lds_window_for_load = make_tile_window(
             q_lds,
             Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
             {0, 0},
-            Policy::template MakeQRegTileDistribution<Problem, decltype(gemm_0)>());
+            Policy::template MakeQRegTileDistribution<Problem>());
         block_sync_lds();
         auto q = load_tile(q_lds_window_for_load);
         __builtin_amdgcn_sched_barrier(0);
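This and the sibling pipelines all drop the explicit decltype(gemm_0) template argument: the policy now recovers the block-gemm type itself. A sketch of the pattern with hypothetical stand-in types (get_qk_block_gemm and Problem1 are illustrative, not the ck_tile API):

```cpp
// Sketch of the call-site simplification: the distribution maker derives the
// block-gemm type from the Problem instead of taking it as a parameter.
#include <type_traits>

struct GemmA {};

template <typename Problem>
constexpr auto get_qk_block_gemm() { return typename Problem::Gemm{}; }

// new-style maker: one template parameter, BlockGemm recovered internally
template <typename Problem>
constexpr auto make_q_reg_tile_distribution()
{
    using BlockGemm = std::remove_cv_t<decltype(get_qk_block_gemm<Problem>())>;
    return BlockGemm{};
}

struct Problem1 { using Gemm = GemmA; };
static_assert(std::is_same_v<decltype(make_q_reg_tile_distribution<Problem1>()), GemmA>);

int main() {}
```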
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp (view file @ 3ee41b40)

@@ -13,15 +13,11 @@ namespace ck_tile {
 // This pipeline is qkv all located in LDS
 struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
     : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                          /* AsyncCopyK = */ false,
-                                          /* AsyncCopyV = */ false,
-                                          /* NumPrefetchK = */ 1,
+                                          /* AsyncCopy = */ false,
                                           /* NumPrefetchV = */ 1>
 {
     using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                                           /* AsyncCopyK = */ false,
-                                                           /* AsyncCopyV = */ false,
-                                                           /* NumPrefetchK = */ 1,
+                                                           /* AsyncCopy = */ false,
                                                            /* NumPrefetchV = */ 1>;

     template <typename Problem>
@@ -76,10 +72,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
                                    sequence<0, 1>>{});
     }

-    template <typename Problem, typename BlockGemm>
+    template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
     {
-        return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>();
+        return BasePolicy::template MakeQDramTileDistribution<Problem>();
     }

     template <typename Problem>
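Every default policy in this commit migrates from five template knobs (AsyncCopyK, AsyncCopyV, NumPrefetchK, NumPrefetchV) to three, with "K loaded once" tied directly to the async flag. A simplified, hypothetical mirror of the consolidated policy (not the real BlockFmhaPipelineQXKSVSCustomPolicy):

```cpp
// Sketch of the policy-parameter consolidation: one AsyncCopy flag replaces
// the separate AsyncCopyK / AsyncCopyV / NumPrefetchK knobs, and KLoadOnce
// is defined as an alias of it.
template <bool QLoadOnce_, bool AsyncCopy_, int NumPrefetchV_>
struct CustomPolicySketch
{
    static constexpr bool AsyncCopy    = AsyncCopy_;
    static constexpr bool KLoadOnce    = AsyncCopy; // async pipeline <=> whole-K staging
    static constexpr int  NumPrefetchV = NumPrefetchV_;
};

// the non-async default policies all become three-argument instantiations:
using SplitKVDefaultPolicySketch = CustomPolicySketch</* QLoadOnce = */ true,
                                                      /* AsyncCopy = */ false,
                                                      /* NumPrefetchV = */ 1>;
static_assert(!SplitKVDefaultPolicySketch::KLoadOnce);

int main() {}
```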
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp (view file @ 3ee41b40)

@@ -180,11 +180,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
         constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();

         auto q_dram_window = make_tile_window(
             q_dram_block_window_tmp.get_bottom_tensor_view(),
             q_dram_block_window_tmp.get_window_lengths(),
             q_dram_block_window_tmp.get_window_origin(),
-            Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
+            Policy::template MakeQDramTileDistribution<Problem>());

         auto q = load_tile(q_dram_window);
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp (view file @ 3ee41b40)

@@ -11,9 +11,7 @@ namespace ck_tile {
 // This pipeline is qkv all located in LDS
 struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
     : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                          /* AsyncCopyK = */ false,
-                                          /* AsyncCopyV = */ false,
-                                          /* NumPrefetchK = */ 1,
+                                          /* AsyncCopy = */ false,
                                           /* NumPrefetchV = */ 1>
 {
     template <typename Problem>
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp (view file @ 3ee41b40)

@@ -35,6 +35,9 @@ struct BlockFmhaPipelineQRKSVS
     static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
     static_assert(kQLoadOnce == Policy::QLoadOnce);
+
+    static constexpr bool kKLoadOnce = false;
+    static_assert(kKLoadOnce == Policy::KLoadOnce);

     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
@@ -178,11 +181,11 @@ struct BlockFmhaPipelineQRKSVS
         constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
         constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();

         auto q_dram_window = make_tile_window(
             q_dram_block_window_tmp.get_bottom_tensor_view(),
             q_dram_block_window_tmp.get_window_lengths(),
             q_dram_block_window_tmp.get_window_origin(),
-            Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
+            Policy::template MakeQDramTileDistribution<Problem>());

         auto q = load_tile(q_dram_window);
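Each pipeline now states its own kQLoadOnce/kKLoadOnce constants and static_asserts them against the policy it is instantiated with, so a mismatched pipeline/policy pair fails at compile time. A sketch of the handshake with simplified stand-in types:

```cpp
// Sketch of the trait handshake the commit adds to each pipeline (simplified
// types; the real traits live on BlockFmhaPipelineQXKSVSCustomPolicy).
struct NonAsyncPolicy
{
    static constexpr bool QLoadOnce = true;
    static constexpr bool KLoadOnce = false;
};

template <typename Policy>
struct PipelineQRKSVSSketch
{
    static constexpr bool kQLoadOnce = true; // q_tile loads the whole hdim at once
    static_assert(kQLoadOnce == Policy::QLoadOnce);

    static constexpr bool kKLoadOnce = false; // this pipeline still slices K
    static_assert(kKLoadOnce == Policy::KLoadOnce);
};

PipelineQRKSVSSketch<NonAsyncPolicy> instantiated_ok{}; // compiles only if traits agree

int main() {}
```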
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp (view file @ 3ee41b40)

@@ -4,7 +4,6 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -12,7 +11,6 @@
 namespace ck_tile {

-// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
 template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
 struct BlockFmhaPipelineQRKSVSAsync
 {
@@ -36,6 +34,9 @@ struct BlockFmhaPipelineQRKSVSAsync
     static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
     static_assert(kQLoadOnce == Policy::QLoadOnce);
+
+    static constexpr bool kKLoadOnce = true;
+    static_assert(kKLoadOnce == Policy::KLoadOnce);

     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
@@ -47,68 +48,51 @@ struct BlockFmhaPipelineQRKSVSAsync
     static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
-    // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
-    //       only need special care about seq_k padding (oob need set -INF of p instead of zero)
-    static constexpr bool kPadSeqLenQ  = true;
+    static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
+                  Problem::kPadHeadDimV == true);
+    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
     static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
-    static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
-    static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
+    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kStoreLSE    = Problem::kStoreLSE;
     static constexpr bool kHasDropout  = Problem::kHasDropout;

     // last dimension vector length used to create tensor view(and decide buffer_load vector length)
     // ... together with tensor distribution. tensor dist should able to overwrite this
-    static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
-    static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
+    static constexpr index_t kAlignmentQ =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
+    static constexpr index_t kAlignmentK =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
     static constexpr index_t kAlignmentV = []() {
         if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
-            return Policy::template GetAlignmentV<Problem>();
+            return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
         else
             return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
     }();
-    static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
+    static constexpr index_t kAlignmentO =
+        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
     static constexpr index_t kAlignmentBias =
         kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();

-#if CK_TILE_FMHA_FWD_FAST_EXP2
-    static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
-#endif
-
     static constexpr index_t kBlockPerCu = []() {
         if constexpr(Problem::kBlockPerCu != -1)
             return Problem::kBlockPerCu;
         else
         {
+            // minimize occupancy
+            if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
+            {
+                return 1;
+            }
+
             if constexpr(kQKHeaddim <= 32)
             {
-                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
-                             FmhaMask::IsMasking)
-                    return 1;
-                else
-                    return 2;
+                return 2;
             }
             else if constexpr(kQKHeaddim <= 64)
             {
-                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
-                    return 2;
-                else
-                    return 3;
+                return 2;
             }
             else if constexpr(kQKHeaddim <= 128)
             {
-                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
-                    return 2;
+                    return 1;
             }
             else if constexpr(kQKHeaddim <= 256)
             {
@@ -142,10 +126,10 @@ struct BlockFmhaPipelineQRKSVSAsync
               typename OAccElementFunction,
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
-    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
+    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
                const QElementFunction& q_element_func,
-               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
-               const KElementFunction& /* k_element_func */,
+               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
+               const KElementFunction& k_element_func,
                const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
                const VElementFunction& v_element_func,
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
@@ -170,50 +154,28 @@ struct BlockFmhaPipelineQRKSVSAsync
         static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                           kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
+                          kSubQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                           kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                           kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                           kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                           kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                       "wrong!");

-        constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
+        constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();

         // K tile in LDS
-        KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
-            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
-        auto k_lds_store = generate_tuple(
-            [&](auto i_buf) {
-                return make_tile_window(
-                    make_tensor_view<address_space_enum::lds>(
-                        k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
-                    Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
-                    {0, 0, 0});
-            },
-            number<Policy::NumPrefetchK>{});
-#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
-        auto k_lds_load = generate_tuple(
-            [&](auto i_buf) {
-                return make_tile_window(
-                    make_tensor_view<address_space_enum::lds>(
-                        k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf)),
-                    Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf).get_lengths(),
-                    {0, 0});
-            },
-            number<Policy::NumPrefetchK>{});
-#else
-        auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
-            k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
-        auto k_lds_load = make_tile_window(
-            k_lds_Load_view,
-            Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
-            {0, 0});
-#endif
+        auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
+        auto k_lds     = make_tensor_view<address_space_enum::lds>(
+            k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
+        auto k_lds_window =
+            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kSubQKHeaddim>{}), {0, 0});

         // V tile in LDS
         auto v_lds = make_tensor_view<address_space_enum::lds>(
-            reinterpret_cast<VDataType*>(smem_ptr),
+            reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
+                                         Policy::template GetSmemSizeK<Problem>()),
             Policy::template MakeVLdsBlockDescriptor<Problem>());
         auto v_lds_window = make_tile_window(
             v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
@@ -222,21 +184,13 @@ struct BlockFmhaPipelineQRKSVSAsync
         constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
         constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();

         auto q_dram_window = make_tile_window(
             q_dram_block_window_tmp.get_bottom_tensor_view(),
             q_dram_block_window_tmp.get_window_lengths(),
             q_dram_block_window_tmp.get_window_origin(),
-            Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
-        q_dram_window.init_raw();
-
-        // TODO: we use async Copy for K, which is inline asm
-        // a side effect is we have to use inline asm for q as well
-        auto q = decltype(load_tile(q_dram_window)){};
-        // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
-        //       however, q would be cleared in the constructor of static distributed tensor
-        // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
-        load_tile_raw(q, q_dram_window);
-        __builtin_amdgcn_sched_barrier(0);
+            Policy::template MakeQDramTileDistribution<Problem>());
+
+        auto q = load_tile(q_dram_window);

         using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
         auto s_acc              = SaccBlockTileType{};
@@ -262,7 +216,6 @@ struct BlockFmhaPipelineQRKSVSAsync
         set_tile(m, -numeric<SMPLComputeDataType>::infinity());
         clear_tile(l);
-        __builtin_amdgcn_sched_barrier(0);

         const auto q_origin = q_dram_window.get_window_origin();
         const auto [seqlen_k_start, seqlen_k_end] =
             mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
@@ -283,13 +236,11 @@ struct BlockFmhaPipelineQRKSVSAsync
                     store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
                 }

-                buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
-                                      // otherwise will have compute error(maybe compiler bug?)
                 // Note: here occ are all cleard, return it
-                // Note: q loaded but no fence, ignore it.
                 return o_acc;
             }
-            __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
         }

         auto k_dram_block_window =
@@ -303,16 +254,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             k_dram_block_window.get_window_origin(),
             Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
                                                                     // load
-        k_dram_window.init_raw();
-
-        constexpr auto k_oob_ck = bool_constant<true>{};
-        constexpr auto k_pre_np = [&]() {
-            if constexpr(kPadSeqLenK &&
-                         (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
-                          (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
-                return bool_constant<true>{};
-            else
-                return bool_constant<false>{};
-        }();
+        auto k_tile = load_tile(k_dram_window);

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
         auto bias_dram_window =
@@ -330,81 +272,58 @@ struct BlockFmhaPipelineQRKSVSAsync
             {0, seqlen_k_start}, // TODO: hdim split?
             Policy::template MakeVDramTileDistribution<Problem>());

-        // prefetch K tile
-        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
-                            k_dram_window,
-                            number<-1>{},
-                            k_oob_ck,
-                            k_pre_np);
-        move_tile_window(k_dram_window, {0, kK0});
-        __builtin_amdgcn_sched_barrier(0);
-
-        buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
-        (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
-        // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
+        auto q_tile = tile_elementwise_in(q_element_func, q);

+        // prefetch K tile
         index_t i_total_loops      = 0;
         constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
-        static_assert(2 <= k0_loops);
+        static_assert(1 <= k0_loops);
         static_assert(1 <= k1_loops);
-        // main loop
         do
         {
             // STAGE 1, QK gemm
             clear_tile(s_acc); // initialize C
-            if constexpr(k0_loops > 1)
-            {
-                static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
-                    async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
-                                        k_dram_window,
-                                        number<-1>{},
-                                        k_oob_ck,
-                                        k_pre_np);
-                    if constexpr(i_k0 < k0_loops - 1)
-                        move_tile_window(k_dram_window, {0, kK0});
-                    async_load_fence(k_dram_window.get_num_of_access());
-                    __builtin_amdgcn_s_barrier();
-                    __builtin_amdgcn_sched_barrier(0);
-                    gemm_0(s_acc,
-                           get_slice_tile(
-                               q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
-#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
-                           k_lds_load[number<LdsSeq.at(number<i_k0>{})>{}]);
-#else
-                           get_slice_tile(
-                               k_lds_load,
-                               sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
-                               sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
-#endif
-                });
-            }
-
-            // TODO: this to fix a bug when loop smaller than 2,
-            //       the following fence/barrier will be scheduled inside 1st loop
-            if constexpr(k0_loops <= 2)
-                __builtin_amdgcn_sched_barrier(0);
-            async_load_fence();
-            __builtin_amdgcn_s_barrier();
+            store_tile(k_lds_window, k_tile);
+            block_sync_lds();
+            if(i_total_loops < num_total_loop - 1)
+            {
+                move_tile_window(k_dram_window, {kN0, 0});
+                k_tile = load_tile(k_dram_window);
+            }
+            __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
+            // for kQKHeaddim == 96 (kSubQKHeaddim == 128), we need to use k0_loops
+            if constexpr(kQKHeaddim == kSubQKHeaddim)
+            {
+                gemm_0(s_acc, q, k_lds_window);
+            }
+            else
+            {
+                static_for<0, k0_loops, 1>{}([&](auto i_k0) {
+                    gemm_0(s_acc,
+                           get_slice_tile(
+                               q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
+                           get_slice_tile(k_lds_window,
                                          sequence<0, i_k0 * kK0>{},
+                                          sequence<kN0, (i_k0 + 1) * kK0>{}));
+                });
+            }

             const auto bias_tile = load_tile(bias_dram_window); // load bias tile
-            auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
+            auto v_buf = load_tile(v_dram_window); // prefetch load v tile
             __builtin_amdgcn_sched_barrier(0);
-            { // tail
-                gemm_0(s_acc,
-                       get_slice_tile(
-                           q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
-#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
-                       k_lds_load[number<LdsSeq.at(number<k0_loops - 1>{})>{}]);
-#else
-                       get_slice_tile(
-                           k_lds_load,
-                           sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
-                           sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
-#endif
-            }
-            __builtin_amdgcn_sched_barrier(1);

             // STAGE 2, scale_s, add bias, mask, softmax
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -457,7 +376,6 @@ struct BlockFmhaPipelineQRKSVSAsync
                     k_origin.at(number<0>{}),
                     number<kM0>{},
                     number<kN0>{});
-
                 if(need_perpixel_check)
                 {
                     set_tile_if(
@@ -484,7 +402,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
                 s.get_tile_distribution()); // Pcompute{j}

-            __builtin_amdgcn_sched_barrier(0x7F);
+            __builtin_amdgcn_sched_barrier(0);
             // store & prefetch next v, after the max reduction
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
@@ -493,9 +411,7 @@ struct BlockFmhaPipelineQRKSVSAsync
                 shuffle_tile(v_shuffle_tmp, v_buf);
                 auto v_lds_window_tmp =
-                    get_slice_tile(v_lds_window,
-                                   sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
-                                   sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
+                    get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
                 store_tile(
                     v_lds_window_tmp,
@@ -504,26 +420,25 @@ struct BlockFmhaPipelineQRKSVSAsync
             else
             {
                 auto v_lds_window_tmp =
-                    get_slice_tile(v_lds_window,
-                                   sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
-                                   sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
+                    get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
                 store_tile(v_lds_window_tmp,
                            tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
             }
-            if constexpr(k1_loops > 1)
+            move_tile_window(v_dram_window, {0, kK1});
+            if constexpr(NumVLdsBuffers > 1)
             {
-                move_tile_window(
-                    v_dram_window,
-                    {0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
-                v_buf = load_tile(
-                    v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
+                v_buf = load_tile(v_dram_window); // load next v_buf
+                move_tile_window(v_dram_window, {0, kK1});
             }
             __builtin_amdgcn_sched_barrier(0);

             static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                 /// NOTICE: bias might be materialized mask including -inf values, need
                 /// consideration. alibi does not have this problem
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                              FmhaMask::IsMasking)
                 {
@@ -597,51 +512,86 @@ struct BlockFmhaPipelineQRKSVSAsync
             if constexpr(kHasDropout)
             {
                 auto randval_ptr =
                     reinterpret_cast<char*>(smem_ptr) +
-                    Policy::template GetSmemSizeKV<Problem>();
+                    Policy::template GetSmemSizeK<Problem>() +
+                    Policy::template GetSmemSizeV<Problem>();
                 dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
                     randval_ptr,
                     seqlen_k_start + i_total_loops * kN0,
                     p_compute,
                     randval_dram_window);
             }

-            const auto p = [&]() {
-                if constexpr(std::is_same_v<PDataType, fp16_t>)
-                    return impl::cast_tile_pk_fp16_fp32<PDataType>(
-                        tile_elementwise_in(p_compute_element_func, p_compute));
-                else
-                    return cast_tile<PDataType>(
-                        tile_elementwise_in(p_compute_element_func, p_compute));
-            }();
+            const auto p =
+                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));

             // STAGE 3, KV gemm
             if constexpr(k1_loops > 1)
             {
-                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
-                    if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
-                        v_buf = load_tile(
-                            v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
-                    block_sync_lds();
-                    gemm_1(o_acc,
-                           get_slice_tile(
-                               p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
-                           get_slice_tile(
-                               v_lds_window,
-                               sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
-                               sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
-                    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
-                    {
-                        auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
-                            Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
-                        shuffle_tile(v_shuffle_tmp, v_buf);
-                        auto v_lds_window_tmp = get_slice_tile(
-                            v_lds_window,
-                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
-                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
-                        store_tile(v_lds_window_tmp,
-                                   tile_elementwise_in(v_element_func,
-                                                       v_shuffle_tmp)); // store the prefetch
-                    }
-                    else
-                    {
-                        auto v_lds_window_tmp = get_slice_tile(
-                            v_lds_window,
-                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
-                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
-                        store_tile(v_lds_window_tmp,
-                                   tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
-                    }
-                    if constexpr(i_k1 < k1_loops - 1)
-                        move_tile_window(v_dram_window, {0, kK1});
-                });
-            }
-            i_total_loops++;
-            if(i_total_loops < num_total_loop)
-            {
-                // move K tile windows
-                move_tile_window(k_dram_block_window, {kN0, 0});
-                k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
-
-                if constexpr(k1_loops >= 2 &&
-                             LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
-                    __builtin_amdgcn_s_barrier();
-                async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
-                                    k_dram_window,
-                                    number<-1>{},
-                                    k_oob_ck,
-                                    k_pre_np);
-                move_tile_window(k_dram_window, {0, kK0});
-            }
+                if constexpr(NumVLdsBuffers == 1)
+                {
+                    static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
+                        if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
+                            v_buf = load_tile(v_dram_window); // load next v_buf
+                        block_sync_lds();
+                        gemm_1(o_acc,
+                               get_slice_tile(
+                                   p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
+                               get_slice_tile(
+                                   v_lds_window,
+                                   sequence<(i_k1 % NumVLdsBuffers) * kN1, 0>{},
+                                   sequence<((i_k1 % NumVLdsBuffers) + 1) * kN1, kK1>{}));
+                        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+                        {
+                            auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
+                                Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
+                            shuffle_tile(v_shuffle_tmp, v_buf);
+                            auto v_lds_window_tmp = get_slice_tile(
+                                v_lds_window,
+                                sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
+                                sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
+                            block_sync_lds();
+                            store_tile(v_lds_window_tmp,
+                                       tile_elementwise_in(v_element_func,
+                                                           v_shuffle_tmp)); // store the prefetch
+                        }
+                        else
+                        {
+                            auto v_lds_window_tmp = get_slice_tile(
+                                v_lds_window,
+                                sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
+                                sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
+                            block_sync_lds();
+                            store_tile(v_lds_window_tmp,
+                                       tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
+                        }
+                        move_tile_window(v_dram_window, {0, kK1});
+                    });
+                }
+                else
+                {
+                    static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
+                        if constexpr(i_k1 > 0 && i_k1 < k1_loops - 1)
+                            v_buf = load_tile(v_dram_window); // load next v_buf
+                        block_sync_lds();
+                        gemm_1(o_acc,
+                               get_slice_tile(
+                                   p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
+                               get_slice_tile(
+                                   v_lds_window,
+                                   sequence<(i_k1 % NumVLdsBuffers) * kN1, 0>{},
+                                   sequence<((i_k1 % NumVLdsBuffers) + 1) * kN1, kK1>{}));
+                        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+                        {
+                            auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
+                                Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
+                            shuffle_tile(v_shuffle_tmp, v_buf);
+                            auto v_lds_window_tmp = get_slice_tile(
+                                v_lds_window,
+                                sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
+                                sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
+                            store_tile(v_lds_window_tmp,
+                                       tile_elementwise_in(v_element_func,
+                                                           v_shuffle_tmp)); // store the prefetch
+                        }
+                        else
+                        {
@@ -650,44 +600,32 @@ struct BlockFmhaPipelineQRKSVSAsync
                            auto v_lds_window_tmp = get_slice_tile(
                                v_lds_window,
-                               sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
-                               sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
+                               sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
+                               sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
                            store_tile(v_lds_window_tmp,
                                       tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
                         }
-                        if constexpr(i_k1 < k1_loops - 1)
+                        if constexpr(i_k1 > 0 && i_k1 < k1_loops - 1)
                             move_tile_window(v_dram_window, {0, kK1});
                     });
+                }
             }
+            // move K tile windows
+            move_tile_window(k_dram_block_window, {kN0, 0});
             // tail
             {
                 block_sync_lds();
                 gemm_1(o_acc,
                        get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
                        get_slice_tile(
                            v_lds_window,
-                           sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
-                           sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
+                           sequence<((k1_loops - 1) % NumVLdsBuffers) * kN1, 0>{},
+                           sequence<(((k1_loops - 1) % NumVLdsBuffers) + 1) * kN1, kK1>{}));
+                block_sync_lds();
             }
-        } while(i_total_loops < num_total_loop);
+        } while(++i_total_loops < num_total_loop);

         // store lse
         if constexpr(kStoreLSE)
@@ -701,11 +639,11 @@ struct BlockFmhaPipelineQRKSVSAsync
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                          BiasEnum == BlockAttentionBiasEnum::ALIBI)
             {
-                lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
+                lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
             }
             else
             {
-                lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
+                lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
             }
 #else
             lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
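With the LdsBufferSequence gone, V's LDS slot for a given k1 iteration is just a ring-buffer index. A standalone sketch of the indexing the new pipeline uses (illustrative values; the real slot count comes from the policy's GetNumVLdsBuffers):

```cpp
// Sketch: V LDS slots as a plain ring buffer of
// NumVLdsBuffers = min(NumPrefetchV, k1_loops) entries indexed by i_k1,
// replacing the precomputed LdsBufferSequence lookup table.
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int NumPrefetchV = 2, k1_loops = 4;
    constexpr int NumVLdsBuffers = std::min(NumPrefetchV, k1_loops);

    for(int i_k1 = 0; i_k1 < k1_loops; ++i_k1)
    {
        int read_slot  = i_k1 % NumVLdsBuffers;       // slice consumed by gemm_1
        int write_slot = (i_k1 + 1) % NumVLdsBuffers; // slice filled by the prefetch
        std::printf("i_k1=%d: read V from slot %d, prefetch into slot %d\n",
                    i_k1, read_slot, write_slot);
    }
}
```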
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp (view file @ 3ee41b40)

@@ -8,12 +8,80 @@
 namespace ck_tile {

-// This pipeline is qkv all located in LDS
-using BlockFmhaPipelineQRKSVSAsyncDefaultPolicy =
-    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                        /* AsyncCopyK = */ true,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 3,
-                                        /* NumPrefetchV = */ 3>;
+struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
+    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
+                                          /* AsyncCopy = */ true,
+                                          /* NumPrefetchV = */ 2>
+{
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
+    {
+        constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim ==
+                                                         Problem::BlockFmhaShape::kSubQKHeaddim)
+                                           ? Problem::BlockFmhaShape::kSubQKHeaddim
+                                           : Problem::BlockFmhaShape::kK0;
+        using GemmProblem =
+            BlockGemmProblem<typename Problem::QDataType,
+                             typename Problem::KDataType,
+                             typename Problem::SaccDataType,
+                             Problem::kNumGemm0Warps * get_warp_size(),
+                             TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                    Problem::BlockFmhaShape::kN0,
+                                                    BlockGemmK>,
+                                           typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                           typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
+
+        constexpr auto warp_gemm = []() {
+            constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
+            static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
+
+            if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
+                         std::is_same_v<typename Problem::KDataType, half_t> &&
+                         std::is_same_v<typename Problem::SaccDataType, float>)
+            {
+                if constexpr(WarpGemmM == 32)
+                    return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
+                else if constexpr(WarpGemmM == 16)
+                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaF16F16F32M4N64K16{};
+            }
+            else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
+                              std::is_same_v<typename Problem::KDataType, bf16_t> &&
+                              std::is_same_v<typename Problem::SaccDataType, float>)
+            {
+                if constexpr(WarpGemmM == 32)
+                    return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
+                else if constexpr(WarpGemmM == 16)
+                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
+            }
+            else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
+                              std::is_same_v<typename Problem::KDataType, fp8_t> &&
+                              std::is_same_v<typename Problem::SaccDataType, float>)
+            {
+                static_assert(WarpGemmM == 32);
+
+                // TODO: hard coded here. Otherwise, it may incorrect result
+                constexpr index_t swizzle_factor = 4;
+                return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
+                    swizzle_factor>{};
+            } // TODO - bf8_t
+        }();
+
+        using BlockGemmPolicy =
+            BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
+                                                 typename Problem::KDataType,
+                                                 typename Problem::SaccDataType,
+                                                 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                                 decltype(warp_gemm)>;
+        if constexpr(1 < Problem::kNumGemm0Warps)
+            return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
+        else
+            return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
+    }
+};

 } // namespace ck_tile
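The new GetQKBlockGemm picks the block-gemm K dimension from the KLoadOnce trait: when the whole head dim sits in LDS and needs no split, gemm_0 can run across kSubQKHeaddim in one call; otherwise it keeps the kK0 slicing. A compile-time sketch of just that selection rule (standalone, not the committed function):

```cpp
// Sketch of the new policy's BlockGemmK choice (mirrors the ternary in
// GetQKBlockGemm above; values are examples).
template <bool KLoadOnce, int kQKHeaddim, int kSubQKHeaddim, int kK0>
constexpr int block_gemm_k()
{
    return (KLoadOnce && kQKHeaddim == kSubQKHeaddim) ? kSubQKHeaddim : kK0;
}

static_assert(block_gemm_k<true, 128, 128, 32>() == 128); // single gemm_0 call
static_assert(block_gemm_k<true, 96, 128, 32>() == 32);   // hdim 96 pads to 128, keep k0_loops
static_assert(block_gemm_k<false, 128, 128, 32>() == 32); // non-load-once pipelines unchanged

int main() {}
```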
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp (view file @ 3ee41b40)

@@ -8,12 +8,9 @@
 namespace ck_tile {

 // This pipeline is qkv all located in LDS
 using BlockFmhaPipelineQRKSVSDefaultPolicy =
     BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                        /* AsyncCopyK = */ false,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 1,
+                                        /* AsyncCopy = */ false,
                                         /* NumPrefetchV = */ 1>;

 } // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp (view file @ 3ee41b40)

@@ -34,6 +34,9 @@ struct BlockFmhaPipelineQSKSVS
     static constexpr bool kQLoadOnce = false;
     static_assert(kQLoadOnce == Policy::QLoadOnce);
+
+    static constexpr bool kKLoadOnce = false;
+    static_assert(kKLoadOnce == Policy::KLoadOnce);

     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
@@ -94,6 +97,8 @@ struct BlockFmhaPipelineQSKSVS
             {
                 return 1;
             }
+            else
+                return 1;
         }
     }();
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp (view file @ 3ee41b40)

@@ -11,9 +11,7 @@ namespace ck_tile {
 // This pipeline is qkv all located in LDS
 struct BlockFmhaPipelineQSKSVSDefaultPolicy
     : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
-                                          /* AsyncCopyK = */ false,
-                                          /* AsyncCopyV = */ false,
-                                          /* NumPrefetchK = */ 1,
+                                          /* AsyncCopy = */ false,
                                           /* NumPrefetchV = */ 1>
 {
     template <typename Problem>
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp (view file @ 3ee41b40)

@@ -17,9 +17,6 @@
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"

-// TODO: remove this
-#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
-
 namespace ck_tile {

 template <bool QLoadOnce_>
@@ -50,9 +47,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
     }

-    template <typename Problem, typename BlockGemm>
+    template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
     {
+        using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
+
         return BlockGemm::template MakeABlockTileDistribution<
             Problem::BlockFmhaShape::kM0,
             Problem::BlockFmhaShape::kSubQKHeaddim>();
@@ -277,72 +276,32 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -277,72 +276,32 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
}
}
};
};
// This pipeline is qkv all located in LDS
template
<
bool
QLoadOnce_
,
bool
AsyncCopy_
,
index_t
NumPrefetchV_
>
template
<
bool
QLoadOnce_
,
bool
AsyncCopyK_
,
bool
AsyncCopyV_
,
index_t
NumPrefetchK_
,
index_t
NumPrefetchV_
>
struct
BlockFmhaPipelineQXKSVSCustomPolicy
:
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
struct
BlockFmhaPipelineQXKSVSCustomPolicy
:
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
{
{
static
constexpr
bool
AsyncCopyK
=
AsyncCopyK_
;
static
constexpr
index_t
NumPrefetchV
=
NumPrefetchV_
;
static
constexpr
bool
AsyncCopyV
=
AsyncCopyV_
;
// TODO: this not supported yet
static
constexpr
index_t
NumPrefetchK
=
NumPrefetchK_
;
static
constexpr
index_t
NumPrefetchV
=
NumPrefetchK_
;
using
QXPolicy
=
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
;
template
<
index_t
k_prefetches_
,
index_t
v_prefetches_
,
index_t
k_loops_
,
index_t
v_loops_
>
struct
LdsBufferSequence
{
static
constexpr
auto
Make
()
{
return
transform_sequences
(
[
&
](
auto
i
)
{
if
(
i
<
k_loops_
)
return
i
%
k_prefetches_
;
return
(
i
-
k_loops_
)
%
v_prefetches_
;
},
typename
arithmetic_sequence_gen
<
0
,
k_loops_
+
v_loops_
,
1
>::
type
{});
};
using
type
=
remove_cvref_t
<
decltype
(
Make
())
>
;
// 1) When Async == true, we preload whole K-tile for next iteration using single LDS buffer,
};
// and preload V-slice for next unroll using multiple LDS buffers
//
clang-format off
//
2) When Async == false, we preload K-slice for next unroll using single LDS buffer, and
template
<
>
struct
// preload V-slice for next unroll using single LDS buffer
LdsBufferSequence
<
3
,
3
,
4
,
4
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
0
,
1
,
2
,
0
>
;
}
;
static
constexpr
bool
AsyncCopy
=
AsyncCopy_
;
template
<
>
struct
static
constexpr
bool
KLoadOnce
=
AsyncCopy
;
LdsBufferSequence
<
3
,
3
,
4
,
2
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
using
QXPolicy
=
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
;
LdsBufferSequence
<
3
,
3
,
2
,
4
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
3
,
3
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
3
,
4
>
{
using
type
=
sequence
<
1
,
2
,
0
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
2
,
2
>
{
using
type
=
sequence
<
1
,
2
,
1
,
0
>
;};
// clang-format on
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
GetLdsBuffer
Sequence
()
CK_TILE_DEVICE
static
constexpr
auto
Get
NumV
LdsBuffer
s
()
{
{
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
return
typename
LdsBufferSequence
<
NumPrefetchK
,
NumPrefetchV
,
k0_loops
,
k1_loops
>::
type
{}
;
return
min
(
NumPrefetchV
,
k1_loops
)
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
@@ -356,15 +315,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
     {
-        using KDataType = remove_cvref_t<typename Problem::KDataType>;
-        if constexpr(AsyncCopyK)
-        {
-            return 4 / sizeof(KDataType);
-        }
-        else
-        {
-            return 16 / sizeof(KDataType);
-        }
+        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+        constexpr index_t kKPerBlock =
+            KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
+
+        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
+
+        constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+        return min(ElemPerThread, MaxVectorSize);
     }

     template <typename Problem>
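The reworked GetAlignmentK bounds the K buffer_load vector width both by 16 bytes and by how many K elements each thread actually owns, which now depends on KLoadOnce. A standalone, compile-checkable sketch of the same arithmetic (fp16 stood in by a 2-byte type; tile sizes are examples):

```cpp
// Sketch mirroring the new GetAlignmentK math (not the ck_tile function).
#include <algorithm>

template <typename KDataType, bool KLoadOnce, int kN0, int kK0, int kSubQKHeaddim, int kBlockSize>
constexpr int alignment_k()
{
    constexpr int kKPerBlock    = KLoadOnce ? kSubQKHeaddim : kK0;
    constexpr int MaxVectorSize = 16 / sizeof(KDataType); // 16 bytes = one dwordx4 load
    constexpr int ElemPerThread = (kN0 * kKPerBlock) / kBlockSize;
    static_assert(0 < ElemPerThread);
    return std::min(ElemPerThread, MaxVectorSize);
}

// 2-byte K elements, 128-wide N0 tile, 256 threads
static_assert(alignment_k<unsigned short, true, 128, 32, 128, 256>() == 8);
// even a kK0-wide slice leaves enough work per thread for 8-wide loads here
static_assert(alignment_k<unsigned short, false, 128, 32, 128, 256>() == 8);

int main() {}
```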
...
@@ -385,14 +345,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -385,14 +345,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
         constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;

-        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
+        constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;

         constexpr index_t kMaxVecLoad =
-            min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
+            min(ElemPerThread, static_cast<index_t>(16 / sizeof(VDataType)));
         constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);

-        constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+        constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
                                          ? kMaxVecLoad
-                                         : (total_pixels / kMinVecLoad);
+                                         : (ElemPerThread / kMinVecLoad);

         return kVecLoad;
     }
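The rename aside, the V vector-load selection is unchanged: take the widest load unless that would leave fewer than kMinVecLoad issues per thread. A worked example under assumed numbers (fp16 V, kN1 = 128, kK1 = 32, 256-thread block):

// Sketch of the kVecLoad selection with assumed numbers.
#include <algorithm>

constexpr int ElemPerThread = 128 * 32 / 256;                  // 16 per thread
constexpr int kMaxVecLoad   = std::min(ElemPerThread, 16 / 2); // at most dword4 of fp16 = 8
constexpr int kMinVecLoad   = 4 / 2;                           // at least dword1 of fp16 = 2
constexpr int kVecLoad      = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
                                  ? kMaxVecLoad
                                  : (ElemPerThread / kMinVecLoad);
static_assert(kVecLoad == 8);
int main() { return 0; }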
 ...
 @@ -422,60 +382,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         return WG::WarpGemmAttribute::Impl::kCM1PerLane;
     }
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
-    {
-        // this function assume K/V can share smem
-        constexpr index_t SingleKSize = [&]() {
-            if constexpr(!AsyncCopyK)
-            {
-                return MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
-            }
-            else
-            {
-                constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-                constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
-                constexpr index_t NumWarps   = Problem::BlockFmhaShape::NumWarps;
-                constexpr index_t warpSize   = ck_tile::get_warp_size();
-
-                constexpr index_t KPack   = GetSmemKPackK<Problem>(); // this is for lds
-                constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
-                constexpr index_t kPad    = KPack;
-
-                static_assert(warpSize * KVector >= kKPerBlock &&
-                              warpSize * KVector % kKPerBlock == 0);
-                constexpr index_t LanesPerK  = kKPerBlock / KVector;
-                constexpr index_t LaneGroups = warpSize / LanesPerK;
-                constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);
-
-                return NumIssues * NumWarps * (warpSize * KVector + kPad);
-            }
-        }();
-
-        constexpr index_t SingleVSize = [&]() {
-            using VDataType = remove_cvref_t<typename Problem::VDataType>;
-            constexpr index_t Banks        = 32; // TODO: need change based on arch
-            constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
-            constexpr index_t kKPack       = GetSmemKPackK<Problem>();
-            static_assert(PixelsPerRow % kKPack == 0);
-            constexpr index_t NPerRow    = PixelsPerRow / kKPack;
-            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
-            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
-            static_assert(kNPerBlock % NPerRow == 0);
-            static_assert(kKPerBlock % kKPack == 0);
-
-            return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
-        }();
-
-        return max(SingleKSize, SingleVSize);
-    }
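The deleted helper sized one shared K/V buffer as max(SingleKSize, SingleVSize). Plugging in assumed fp16 numbers (warp size 64, 4 warps, kN0 = kN1 = 128, kKPerBlock = 32, KVector = 2, KPack = kPad = 8) makes the arithmetic concrete; these values are illustrative only:

// Worked example of the removed shared K/V sizing, in element counts.
constexpr int LanesPerK    = 32 / 2;                 // 16 lanes span one K row
constexpr int LaneGroups   = 64 / LanesPerK;         // 4 N rows per warp per issue
constexpr int NumIssues    = 128 / (LaneGroups * 4); // 8
constexpr int SingleKSize  = NumIssues * 4 * (64 * 2 + 8);              // 4352
constexpr int PixelsPerRow = 32 * 4 / 2;             // 32 banks * 4 B / sizeof(fp16) = 64
constexpr int SingleVSize  = (32 / 8) * (128 / 8) * (PixelsPerRow + 8); // 4608
constexpr int SingleSize   = SingleKSize < SingleVSize ? SingleVSize : SingleKSize;
static_assert(SingleSize == 4608); // V layout dominates under these assumptions
int main() { return 0; }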
-    // TODO: this is used for non async copy desc. unify in the future
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
     {
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+        constexpr index_t kKPerBlock =
+            KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
         constexpr index_t kKPack     = GetSmemKPackK<Problem>();

         constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
 ...
 @@ -495,164 +407,26 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo

         return k_lds_block_desc;
     }
-    template <typename Problem, index_t IBuf = 0>
-    CK_TILE_HOST_DEVICE static constexpr auto
-    MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
-    {
-        // K is always k-major, we use async-copy to load into LDS
-        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
-        constexpr index_t kBlockSize = Problem::kBlockSize;
-        constexpr index_t NumWarps   = Problem::BlockFmhaShape::NumWarps;
-
-        constexpr index_t warpSize = ck_tile::get_warp_size();
-
-        constexpr index_t KPack   = GetSmemKPackK<Problem>(); // this is for lds
-        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
-        constexpr index_t kPad =
-            KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
-
-        static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
-        constexpr index_t LanesPerK =
-            kKPerBlock / KVector; // how many lane (within a wave) to load K
-        constexpr index_t LaneGroups =
-            warpSize / LanesPerK; // how many groups (within a wave), they may load different N,
-                                  // but same K
-        constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
-        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
-
-        constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
-            make_tuple(number<NumIssues>{},  // n0
-                       number<LaneGroups>{}, // n1
-                       number<NumWarps>{},   // n2
-                       number<LanesPerK>{},  // k0
-                       number<KVector>{}),   // k1
-            make_tuple(number<NumWarps * (warpSize * KVector + kPad)>{},
-                       number<kKPerBlock>{},
-                       number<warpSize * KVector + kPad>{},
-                       number<KVector>{},
-                       number<1>{}),
-            number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
-            number<KVector>{},
-            number<1>{});
-
-        // TODO this layout is hard coded, and will be used in async copy buffer view load
-        // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
-        constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
-            k_lds_block_desc_0,
-            make_tuple(make_pass_through_transform(number<NumIssues>{}),
-                       make_pass_through_transform(number<NumWarps>{}),
-                       make_merge_transform(make_tuple(
-                           number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
-            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
-
-        return k_lds_block_desc_issues_warps_lanes;
-    }
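The static_assert in the deleted store descriptor encodes a coverage identity: NumIssues async stores of KVector elements from each of the kBlockSize lanes must tile the whole kNPerBlock x kKPerBlock block exactly. A standalone check with the same hypothetical fp16 shape as above:

// Coverage identity behind the removed NumIssues static_assert.
constexpr int warpSize = 64, NumWarps = 4, kBlockSize = warpSize * NumWarps;
constexpr int kNPerBlock = 128, kKPerBlock = 32, KVector = 2; // assumed shape
constexpr int LanesPerK  = kKPerBlock / KVector;              // 16 lanes span one K row
constexpr int LaneGroups = warpSize / LanesPerK;              // 4 N rows per warp per issue
constexpr int NumIssues  = kNPerBlock / (LaneGroups * NumWarps); // 8
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
int main() { return 0; }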
-#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
-    template <typename Problem, index_t IBuf = 0>
-    CK_TILE_HOST_DEVICE static constexpr auto
-    MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
-    {
-        // K is always k-major, we use async-copy to load into LDS
-        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
-        constexpr index_t kBlockSize = Problem::kBlockSize;
-        constexpr index_t NumWarps   = Problem::BlockFmhaShape::NumWarps;
-
-        constexpr index_t warpSize = ck_tile::get_warp_size();
-
-        constexpr index_t KPack   = GetSmemKPackK<Problem>(); // this is for lds
-        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
-        constexpr index_t kPad    = KPack; // for async-copy, this pad is between warps
-
-        static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
-        constexpr index_t LanesPerK  = kKPerBlock / KVector; // within a wave
-        constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
-        constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);
-        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
-
-        constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
-            make_tuple(number<NumIssues>{},          // n0
-                       number<NumWarps>{},           // n2
-                       number<LaneGroups>{},         // n1
-                       number<kKPerBlock / KPack>{}, // k0
-                       number<KPack>{}),             // k1
-            make_tuple(number<NumWarps * (warpSize * KVector + kPad)>{},
-                       number<warpSize * KVector + kPad>{},
-                       number<kKPerBlock>{},
-                       number<KPack>{},
-                       number<1>{}),
-            number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
-            number<KPack>{},
-            number<1>{});
-
-        constexpr auto k_lds_block_desc = transform_tensor_descriptor(
-            k_lds_block_desc_0,
-            make_tuple(make_merge_transform(make_tuple(
-                           number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
-                       make_merge_transform(
-                           make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
-            make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-
-        return k_lds_block_desc;
-    }
-#else
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
+    CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
     {
-        // K is always k-major, we use async-copy to load into LDS
-        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
-        constexpr index_t kBlockSize = Problem::kBlockSize;
-        constexpr index_t NumWarps   = Problem::BlockFmhaShape::NumWarps;
-
-        constexpr index_t warpSize = ck_tile::get_warp_size();
-
-        constexpr index_t KPack   = GetSmemKPackK<Problem>(); // this is for lds
-        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
-        constexpr index_t kPad    = KPack; // for async-copy, this pad is between warps
-
-        static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
-        constexpr index_t LanesPerK  = kKPerBlock / KVector; // within a wave
-        constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
-        constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);
-        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
-
-        // constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
-        // constexpr index_t SingleVSize =
-        //     MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
-        constexpr index_t BufferSize =
-            GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
-
-        constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number<NumPrefetchK>{},       // num_buffers
-                       number<NumIssues>{},          // n0
-                       number<NumWarps>{},           // n2
-                       number<LaneGroups>{},         // n1
-                       number<kKPerBlock / KPack>{}, // k0
-                       number<KPack>{}),             // k1
-            make_tuple(number<BufferSize>{},
-                       number<NumWarps * (warpSize * KVector + kPad)>{},
-                       number<warpSize * KVector + kPad>{},
-                       number<kKPerBlock>{},
-                       number<KPack>{},
-                       number<1>{}),
-            number<KPack>{},
-            number<1>{});
-
-        constexpr auto k_lds_block_desc = transform_tensor_descriptor(
-            k_lds_block_desc_0,
-            make_tuple(make_merge_transform(make_tuple(number<NumPrefetchK>{},
-                                                       number<NumIssues>{},
-                                                       number<LaneGroups>{},
-                                                       number<NumWarps>{})),
-                       make_merge_transform(
-                           make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
-            make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-
-        return k_lds_block_desc;
+        constexpr index_t SingleVSize = [&]() {
+            using VDataType = remove_cvref_t<typename Problem::VDataType>;
+            constexpr index_t Banks        = 32; // TODO: need change based on arch
+            constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
+            constexpr index_t kKPack       = GetSmemKPackV<Problem>();
+            static_assert(PixelsPerRow % kKPack == 0);
+            constexpr index_t NPerRow    = PixelsPerRow / kKPack;
+            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
+            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+            static_assert(kNPerBlock % NPerRow == 0);
+            static_assert(kKPerBlock % kKPack == 0);
+
+            return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
+        }();
+
+        return SingleVSize;
     }
-#endif
     // 3d + padding
     template <typename Problem>
 ...
 @@ -669,13 +443,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         static_assert(kNPerBlock % NPerRow == 0);
         static_assert(kKPerBlock % kKPack == 0);

+        constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
+
         constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number<NumPrefetchV>{},
+            make_tuple(number<NumVLdsBuffers>{},
                        number<kKPerBlock / kKPack>{},
                        number<kNPerBlock / NPerRow>{},
                        number<NPerRow>{},
                        number<kKPack>{}),
-            make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
+            make_tuple(number<GetVSingleSmemElementSpaceSize<Problem>()>{},
                        number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
                        number<PixelsPerRow + kKPack>{},
                        number<kKPack>{},
 ...
 @@ -687,7 +463,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
             v_lds_block_desc_0,
             make_tuple(
-                make_merge_transform(make_tuple(number<NumPrefetchV>{},
+                make_merge_transform(make_tuple(number<NumVLdsBuffers>{},
                                                 number<kNPerBlock / NPerRow>{},
                                                 number<NPerRow>{})),
                 make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
             make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));
 ...
 @@ -696,28 +472,26 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
     {
-        // TODO: assume Q is in register
-        // TODO: assume K/V has same data type
-        constexpr index_t single_smem_size =
-            GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
-
-        return QXPolicy::template GetSmemSizeQ<Problem>() +
-               single_smem_size * max(NumPrefetchK, NumPrefetchV);
+        return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
+               sizeof(typename Problem::KDataType);
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
     {
-        if constexpr (AsyncCopyK)
-        {
-            return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
-        }
-        else
-        {
-            return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>(0));
-        }
+        return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
+               sizeof(typename Problem::VDataType);
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        // assume Q can reuse the shared memory with K or V
+        return max(QXPolicy::template GetSmemSizeQ<Problem>(),
+                   GetSmemSizeK<Problem>() + GetSmemSizeV<Problem>()) +
+               GetSmemSizeDropout<Problem>(0);
     }
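The new split makes the LDS budget explicit: Q may alias the K-plus-V region, so the total is max(Q, K + V) plus the dropout scratch. A minimal sketch with assumed byte sizes (illustrative only):

// Sketch of the new GetSmemSize() composition with assumed byte sizes.
#include <algorithm>

constexpr int SmemQ = 4096, SmemK = 8192, SmemV = 9216, SmemDropout = 0;
constexpr int SmemTotal = std::max(SmemQ, SmemK + SmemV) + SmemDropout; // 17408
static_assert(SmemTotal == 17408);
int main() { return 0; }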
     // this method is only available when Problem::kHasDropout is present
 ...
 @@ -753,60 +527,35 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
     {
-        if constexpr (!AsyncCopyK)
-        {
-            using KDataType = remove_cvref_t<typename Problem::KDataType>;
-            constexpr index_t kBlockSize = Problem::kBlockSize;
-            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
-
-            constexpr index_t K1 = 16 / sizeof(KDataType);
-            constexpr index_t K0 = kKPerBlock / K1;
-            constexpr index_t N2 = get_warp_size() / K0;
-            constexpr index_t N1 = kBlockSize / get_warp_size();
-            constexpr index_t N0 = kNPerBlock / (N2 * N1);
-
-            return make_static_tile_distribution(
-                tile_distribution_encoding<sequence<1>,
-                                           tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
-                                           tuple<sequence<1>, sequence<1, 2>>,
-                                           tuple<sequence<1>, sequence<2, 0>>,
-                                           sequence<1, 2>,
-                                           sequence<0, 1>>{});
-        }
-        else
-        {
-            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
-            constexpr index_t kBlockSize = Problem::kBlockSize;
-            constexpr index_t NumWarps   = Problem::BlockFmhaShape::NumWarps;
-            constexpr index_t warpSize   = ck_tile::get_warp_size();
-            constexpr index_t KVector    = GetAlignmentK<Problem>(); // this is for global load
-
-            static_assert(warpSize * KVector >= kKPerBlock &&
-                          warpSize * KVector % kKPerBlock == 0);
-            constexpr index_t LanesPerK  = kKPerBlock / KVector; // within a wave
-            constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
-            constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);
-            static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
-
-            constexpr index_t N0 = NumIssues;
-            constexpr index_t N1 = LaneGroups;
-            constexpr index_t N2 = NumWarps;
-            constexpr index_t K0 = LanesPerK;
-            constexpr index_t K1 = KVector;
-
-            return make_static_tile_distribution(
-                tile_distribution_encoding<sequence<1>,
-                                           tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
-                                           tuple<sequence<1>, sequence<1, 2>>,
-                                           tuple<sequence<2>, sequence<1, 0>>,
-                                           sequence<1, 2>,
-                                           sequence<0, 1>>{});
-        }
+        using KDataType = remove_cvref_t<typename Problem::KDataType>;
+        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+        constexpr index_t kKPerBlock =
+            KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
+
+        constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
+
+        constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+        constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
+
+        constexpr index_t KPerThread     = kMaxVecLoad;
+        constexpr index_t KThreads       = kKPerBlock / KPerThread;
+        constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
+        constexpr index_t NumWarps       = kBlockSize / get_warp_size();
+        constexpr index_t NPerThread     = kNPerBlock / (NThreadPerWarp * NumWarps);
+
+        return make_static_tile_distribution(
+            tile_distribution_encoding<sequence<1>,
+                                       tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
+                                             sequence<KThreads, KPerThread>>,
+                                       tuple<sequence<1>, sequence<1, 2>>,
+                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 1>>{});
     }
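The new distribution decomposes the block along K first (KThreads lanes of KPerThread-wide vectors) and spreads the remaining lanes and warps over N. A standalone consistency check under assumed numbers (fp16, 256-thread block, kN0 = 128, kKPerBlock = kSubQKHeaddim = 64):

// Consistency check of the K-first thread decomposition, assumed numbers.
constexpr int warpSize       = 64;
constexpr int KPerThread     = 8;                   // kMaxVecLoad: dword4 of fp16
constexpr int KThreads       = 64 / KPerThread;     // 8 lanes span the K axis
constexpr int NThreadPerWarp = warpSize / KThreads; // 8 N rows per warp
constexpr int NumWarps       = 256 / warpSize;      // 4
constexpr int NPerThread     = 128 / (NThreadPerWarp * NumWarps); // 4 N rows per lane
static_assert(NPerThread * NThreadPerWarp * NumWarps == 128); // N axis covered
static_assert(KThreads * KPerThread == 64);                   // K axis covered
int main() { return 0; }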
     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
 ...
 @@ -822,9 +571,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         constexpr index_t N1 = GetAlignmentV<Problem>();
         constexpr index_t N0 = kNPerBlock / N1; // P

-        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-        static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-        constexpr index_t K3 = total_pixels / N1;
+        constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
+        static_assert(ElemPerThread % N1 == 0); // TODO: this is not always true?
+        constexpr index_t K3 = ElemPerThread / N1;
         constexpr index_t kKPack = GetSmemKPackV<Problem>();
         static_assert(kKPack % K3 == 0);
         constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
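The renamed ElemPerThread feeds the same split as before: each lane holds N1 contiguous V elements K3 times, and K2 = kKPack / K3 lanes cooperate to fill one kKPack. Assumed fp16 numbers for illustration:

// Worked example of the K3/K2 split (fp16 V, 256 threads, kN1 = 128, kK1 = 32).
constexpr int N1            = 8;              // GetAlignmentV: contiguous N per lane
constexpr int ElemPerThread = 128 * 32 / 256; // 16
static_assert(ElemPerThread % N1 == 0);
constexpr int K3 = ElemPerThread / N1;        // 2 K steps per lane
constexpr int kKPack = 8;
static_assert(kKPack % K3 == 0);
constexpr int K2 = kKPack / K3;               // 4 lanes complete one kKPack
int main() { return 0; }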
 ...
 @@ -895,9 +644,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         constexpr index_t N1 = GetAlignmentV<Problem>();
         constexpr index_t N0 = kNPerBlock / N1;

-        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-        static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-        constexpr index_t K3 = total_pixels / N1;
+        constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
+        static_assert(ElemPerThread % N1 == 0); // TODO: this is not always true?
+        constexpr index_t K3 = ElemPerThread / N1;
         constexpr index_t kKPack = GetSmemKPackV<Problem>();
         static_assert(kKPack % K3 == 0);
         constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
 ...