gaoqiong / composable_kernel_ROCM · Commits
"megatron/training/arguments.py" did not exist on "941a793fce69dd4a8a6845adc14c1ab17a3fa87b"
Commit f84e2020 (Unverified)
Authored Aug 26, 2024 by Rostyslav Geyyer; committed by GitHub, Aug 26, 2024

    Merge branch 'develop' into lwpck-1815

Parents: 408534d4, 25935b57
Changes: 175 · Showing 20 changed files with 4338 additions and 2285 deletions (+4338 −2285)
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp                                                +555  −332
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp                                      +0    −54
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp                                                +2    −4
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp                                +21   −18
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp                                        +12   −13
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp                                    +141  −0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp                                      +3    −3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp                       +0    −20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp                   +782  −0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp              +1037 −0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp    +0    −20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp                       +0    −821
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp        +0    −20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp +0    −20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp                       +1531 −954
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp                                 +2    −3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp                              +37   −3
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp                                             +10   −0
include/ck_tile/ops/gemm.hpp                                                                       +3    −0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp                                    +202  −0
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp (view file @ f84e2020)

```diff
@@ -23,13 +23,9 @@
 namespace ck_tile {

-template <typename TilePartitioner_,
-          typename FmhaPipeline_,
+template <typename FmhaPipeline_,
           typename KGradEpiloguePipeline_,
           typename VGradEpiloguePipeline_>
 struct FmhaBwdDQDKDVKernel
 {
-    using TilePartitioner       = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline          = ck_tile::remove_cvref_t<FmhaPipeline_>;
     using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
     using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
```
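The tile partitioner is no longer a template parameter; grid sizing and tile indexing become members of the kernel itself (see the GridSize/GetTileIndex hunks below). A minimal standalone sketch of this refactor pattern, with all names invented for illustration:

```cpp
// Sketch only: a policy that used to be a template parameter (the tile
// partitioner) is folded into the kernel as a static member function.
struct PipelineA { static constexpr int kN0 = 128; }; // hypothetical pipeline

template <typename Pipeline>
struct Kernel
{
    // was: Partitioner::GridSize(...); now the mapping is self-contained
    static constexpr int grid_x(int seqlen_k)
    {
        return (seqlen_k + Pipeline::kN0 - 1) / Pipeline::kN0; // ceil-divide
    }
};

static_assert(Kernel<PipelineA>::grid_x(1000) == 8); // 1000 keys / 128 per tile
```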
```diff
@@ -59,9 +55,12 @@ struct FmhaBwdDQDKDVKernel
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
     static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
-    static constexpr bool kHasDropout  = FmhaPipeline::kHasDropout;
     using FmhaMask    = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+    using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
     static constexpr bool kHasMask         = FmhaMask::IsMasking;
+    static constexpr bool kHasDropout      = FmhaDropout::IsDropout;
+    static constexpr bool kIsStoreRandval  = FmhaDropout::IsStoreRandval;
     static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;

     // clang-format off
     template <typename T> struct t2s;

@@ -74,8 +73,11 @@ struct FmhaBwdDQDKDVKernel
         // sync with generate.py
         // clang-format off
         using bfs  = typename FmhaPipeline::BlockFmhaShape;
-        using gbr  = typename bfs::Gemm0BlockWarps;
-        using gwt  = typename bfs::Gemm0WarpTile;
+        using gbr0 = typename bfs::Gemm0BlockWarps;
+        using gbr1 = typename bfs::Gemm1BlockWarps;
+        using gbr4 = typename bfs::Gemm4BlockWarps;
+        using gwt0 = typename bfs::Gemm0WarpTile;
+        using gwt1 = typename bfs::Gemm1WarpTile;
         #define _SS_ std::string
         #define _TS_ std::to_string
         auto pn = [&] () {
```
```diff
@@ -88,13 +90,17 @@ struct FmhaBwdDQDKDVKernel
         return
             _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
             "_" + (kIsGroupMode ? "group" : "batch") + "_" +
-            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
-                  _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
-            "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
-            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
+            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
+                  _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
+                  _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
+            "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
+            "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
+            "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
+            "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
+            "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
             ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-            (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
-            (kHasDropout ? "_dropout" : "");
+            (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
+            (kHasDropout ? "_dropout" : "") + (kIsStoreRandval ? "_storerandval" : "") +
+            (kIsDeterministic ? "_deterministic" : "");
         #undef _SS_
         #undef _TS_
         // clang-format on
```
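GetName stitches the compile-time tile configuration into a unique kernel-instance string (kept in sync with generate.py); the new scheme encodes all five K-tile sizes and separate block-warp/warp-tile splits for Gemm0/Gemm1/Gemm4. A standalone sketch of just the string-assembly pattern, where every numeric value is an assumed stand-in rather than a configuration from this commit:

```cpp
#include <iostream>
#include <string>

int main()
{
    // Hypothetical tile constants, only for demonstrating the name layout.
    constexpr int kM0 = 32, kN0 = 128, kK0 = 32, kK1 = 32, kK2 = 32, kK3 = 32, kK4 = 32;
    constexpr int kQKHeaddim = 128, kVHeaddim = 128;
    const std::string name =
        "fmha_bwd_d" + std::to_string(kQKHeaddim) + "_fp16_batch_b" +
        std::to_string(kM0) + "x" + std::to_string(kN0) + "x" + std::to_string(kK0) + "x" +
        std::to_string(kK1) + "x" + std::to_string(kK2) + "x" + std::to_string(kK3) + "x" +
        std::to_string(kK4) + "x" + std::to_string(kQKHeaddim) + "x" + std::to_string(kVHeaddim);
    std::cout << name << '\n'; // fmha_bwd_d128_fp16_batch_b32x128x32x32x32x32x32x128x128
    return 0;
}
```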
```diff
@@ -117,7 +123,7 @@ struct FmhaBwdDQDKDVKernel
         const void* lse_ptr;
         const void* do_ptr;
         const void* d_ptr;
-        void* dq_ptr;
+        void* dq_acc_ptr;
         void* dk_ptr;
         void* dv_ptr;

@@ -131,14 +137,13 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t num_head_q;
         ck_tile::index_t nhead_ratio_qk;
         float raw_scale;
 #if CK_TILE_FMHA_FWD_FAST_EXP2
         float scale;
 #endif
         ck_tile::index_t stride_q;
         ck_tile::index_t stride_k;
         ck_tile::index_t stride_v;
         ck_tile::index_t stride_do;
+        ck_tile::index_t stride_dq_acc;
         ck_tile::index_t stride_dk;
         ck_tile::index_t stride_dv;

@@ -147,8 +152,9 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t nhead_stride_v;
         ck_tile::index_t nhead_stride_do;
         ck_tile::index_t nhead_stride_lsed;
-        ck_tile::index_t batch_stride_lsed;
+        ck_tile::index_t nhead_stride_dq_acc;
+        ck_tile::index_t nhead_stride_dk;
+        ck_tile::index_t nhead_stride_dv;
     };

     struct FmhaBwdCommonBiasKargs

@@ -206,7 +212,6 @@ struct FmhaBwdDQDKDVKernel
         float rp_undrop             = 1;
         float scale_rp_undrop       = 1;
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        bool is_store_randval       = false;
         uint64_t drop_seed          = 1;
         uint64_t drop_offset        = 0;
         void* rand_val_ptr          = nullptr;

@@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel
     {
         ck_tile::index_t batch_stride_randval = 0;
     };

+    struct FmhaBwdDeterministicKargs
+    {
+        ck_tile::index_t split_stride_dq_acc = 0;
+    };

     struct FmhaBwdBatchModeKargs
         : FmhaBwdCommonKargs,
```
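The karg structs compose optional argument groups through std::conditional_t over indexed empty structs, so disabled features cost no kernel-argument bytes thanks to empty-base optimization. A minimal standalone sketch of the idiom, with invented names:

```cpp
#include <type_traits>

template <int I>
struct EmptyKargs // index I avoids the duplicated-base-class problem
{
};

struct DeterministicKargs
{
    long split_stride_dq_acc = 0;
};

template <bool kIsDeterministic>
struct BatchModeKargs
    : std::conditional_t<kIsDeterministic, DeterministicKargs, EmptyKargs<4>>
{
    long batch_stride_q;
};

// Empty base contributes no storage; the enabled base adds exactly its members.
static_assert(sizeof(BatchModeKargs<false>) == sizeof(long));
static_assert(sizeof(BatchModeKargs<true>) == 2 * sizeof(long));
```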
```diff
@@ -228,12 +237,15 @@ struct FmhaBwdDQDKDVKernel
                              FmhaBwdEmptyKargs<0>>>,
           std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
           std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>
+          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
+          std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
         ck_tile::index_t batch_stride_do;
+        ck_tile::index_t batch_stride_lsed;
         ck_tile::index_t batch_stride_dq_acc;
         ck_tile::index_t batch_stride_dk;
         ck_tile::index_t batch_stride_dv;
     };

@@ -247,7 +259,8 @@ struct FmhaBwdDQDKDVKernel
                              FmhaBwdEmptyKargs<0>>>,
           std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
           std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>
+          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
+          std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
```
```diff
@@ -266,10 +279,10 @@ struct FmhaBwdDQDKDVKernel
               const void* do_ptr,
               const void* d_ptr,
               void* rand_val_ptr,
-              void* dq_ptr,
               void* dk_ptr,
               void* dv_ptr,
               void* dbias_ptr,
+              void* dq_acc_ptr,
               ck_tile::index_t seqlen_q,
               ck_tile::index_t seqlen_k,
               ck_tile::index_t hdim_q,

@@ -283,6 +296,7 @@ struct FmhaBwdDQDKDVKernel
               ck_tile::index_t stride_bias,
               ck_tile::index_t stride_randval,
               ck_tile::index_t stride_do,
+              ck_tile::index_t stride_dq_acc,
               ck_tile::index_t stride_dk,
               ck_tile::index_t stride_dv,
               ck_tile::index_t stride_dbias,

@@ -293,6 +307,9 @@ struct FmhaBwdDQDKDVKernel
               ck_tile::index_t nhead_stride_randval,
               ck_tile::index_t nhead_stride_do,
               ck_tile::index_t nhead_stride_lsed,
+              ck_tile::index_t nhead_stride_dq_acc,
+              ck_tile::index_t nhead_stride_dk,
+              ck_tile::index_t nhead_stride_dv,
               ck_tile::index_t nhead_stride_dbias,
               ck_tile::index_t batch_stride_q,
               ck_tile::index_t batch_stride_k,

@@ -301,14 +318,15 @@ struct FmhaBwdDQDKDVKernel
               ck_tile::index_t batch_stride_randval,
               ck_tile::index_t batch_stride_do,
               ck_tile::index_t batch_stride_lsed,
+              ck_tile::index_t batch_stride_dq_acc,
               ck_tile::index_t batch_stride_dk,
               ck_tile::index_t batch_stride_dv,
               ck_tile::index_t batch_stride_dbias,
+              ck_tile::index_t split_stride_dq_acc,
               ck_tile::index_t window_size_left,
               ck_tile::index_t window_size_right,
               ck_tile::index_t mask_type,
               float p_drop,
               bool s_randval,
               const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,

@@ -317,7 +335,7 @@ struct FmhaBwdDQDKDVKernel
                      lse_ptr,
                      do_ptr,
                      d_ptr,
-                     dq_ptr,
+                     dq_acc_ptr,
                      dk_ptr,
                      dv_ptr,
                      seqlen_q,
```
```diff
@@ -327,13 +345,12 @@ struct FmhaBwdDQDKDVKernel
                      num_head_q,
                      nhead_ratio_qk,
                      scale,
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale * ck_tile::log2e_v<>),
 #endif
                      stride_q,
                      stride_k,
                      stride_v,
                      stride_do,
+                     stride_dq_acc,
                      stride_dk,
                      stride_dv,
                      nhead_stride_q,

@@ -341,15 +358,20 @@ struct FmhaBwdDQDKDVKernel
                      nhead_stride_v,
                      nhead_stride_do,
                      nhead_stride_lsed,
-                     batch_stride_lsed}, // args for common karg
+                     nhead_stride_dq_acc,
+                     nhead_stride_dk,
+                     nhead_stride_dv}, // args for common karg
                     {},                // placeholder for bias
                     {},                // placeholder for dbias
                     {},                // placeholder for mask
                     {},                // placeholder for dropout
+                    {},                // placeholder for deterministic
                     batch_stride_q,
                     batch_stride_k,
                     batch_stride_v,
                     batch_stride_do,
+                    batch_stride_lsed,
                     batch_stride_dq_acc,
                     batch_stride_dk,
                     batch_stride_dv};

@@ -384,11 +406,18 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset, scale);
+            if constexpr(kIsStoreRandval)
+            {
                 kargs.rand_val_ptr         = rand_val_ptr;
                 kargs.stride_randval       = stride_randval;
                 kargs.nhead_stride_randval = nhead_stride_randval;
                 kargs.batch_stride_randval = batch_stride_randval;
-                kargs.is_store_randval     = s_randval;
+            }
         }
+        if constexpr(kIsDeterministic)
+        {
+            kargs.split_stride_dq_acc = split_stride_dq_acc;
+        }

         return kargs;
```
```diff
@@ -404,10 +433,10 @@ struct FmhaBwdDQDKDVKernel
               const void* do_ptr,
               const void* d_ptr,
               void* rand_val_ptr,
-              void* dq_ptr,
               void* dk_ptr,
               void* dv_ptr,
               void* dbias_ptr,
+              void* dq_acc_ptr,
               const void* seqstart_q_ptr,
               const void* seqstart_k_ptr,
               const void* seqlen_k_ptr,

@@ -422,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
               ck_tile::index_t stride_bias,
               ck_tile::index_t stride_randval,
               ck_tile::index_t stride_do,
+              ck_tile::index_t stride_dq_acc,
               ck_tile::index_t stride_dk,
               ck_tile::index_t stride_dv,
               ck_tile::index_t stride_dbias,

@@ -432,13 +462,15 @@ struct FmhaBwdDQDKDVKernel
               ck_tile::index_t nhead_stride_randval,
               ck_tile::index_t nhead_stride_do,
               ck_tile::index_t nhead_stride_lsed,
+              ck_tile::index_t nhead_stride_dq_acc,
+              ck_tile::index_t nhead_stride_dk,
+              ck_tile::index_t nhead_stride_dv,
               ck_tile::index_t nhead_stride_dbias,
-              ck_tile::index_t batch_stride_lsed,
+              ck_tile::index_t split_stride_dq_acc,
               ck_tile::index_t window_size_left,
               ck_tile::index_t window_size_right,
               ck_tile::index_t mask_type,
               float p_drop,
               bool s_randval,
               const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,

@@ -447,7 +479,7 @@ struct FmhaBwdDQDKDVKernel
                      lse_ptr,
                      do_ptr,
                      d_ptr,
-                     dq_ptr,
+                     dq_acc_ptr,
                      dk_ptr,
                      dv_ptr,
                      -1, // seqlen will be updated by another pointer

@@ -457,13 +489,12 @@ struct FmhaBwdDQDKDVKernel
                      num_head_q,
                      nhead_ratio_qk,
                      scale,
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale * ck_tile::log2e_v<>),
 #endif
                      stride_q,
                      stride_k,
                      stride_v,
                      stride_do,
+                     stride_dq_acc,
                      stride_dk,
                      stride_dv,
                      nhead_stride_q,

@@ -471,11 +502,14 @@ struct FmhaBwdDQDKDVKernel
                      nhead_stride_v,
                      nhead_stride_do,
                      nhead_stride_lsed,
-                     batch_stride_lsed}, // args for common karg
+                     nhead_stride_dq_acc,
+                     nhead_stride_dk,
+                     nhead_stride_dv}, // args for common karg
                     {},                // placeholder for bias
                     {},                // placeholder for dbias
                     {},                // placeholder for mask
                     {},                // placeholder for dropout
+                    {},                // placeholder for deterministic
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

@@ -506,10 +540,16 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset, scale);
+            if constexpr(kIsStoreRandval)
+            {
                 kargs.rand_val_ptr         = rand_val_ptr;
                 kargs.stride_randval       = stride_randval;
                 kargs.nhead_stride_randval = nhead_stride_randval;
-                kargs.is_store_randval     = s_randval;
+            }
         }
+        if constexpr(kIsDeterministic)
+        {
+            kargs.split_stride_dq_acc = split_stride_dq_acc;
+        }

         return kargs;
```
```diff
@@ -518,7 +558,17 @@ struct FmhaBwdDQDKDVKernel
     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                 ck_tile::index_t nhead_,
                                                 ck_tile::index_t seqlen_k_)
     {
-        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_);
+        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
     }

+    CK_TILE_DEVICE static constexpr auto GetTileIndex()
+    {
+        const index_t i_block = blockIdx.x;
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
+    }
+
     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

@@ -536,7 +586,7 @@ struct FmhaBwdDQDKDVKernel
         __shared__ char smem_ptr[GetSmemSize()];

         // divide problem
-        const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);
+        const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();

         const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
```
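The grid mapping is now fixed inside the kernel: one block per kN0-wide key tile on x, heads on y, batches on z. A host-only sketch of the same arithmetic, with a stand-in Dim3 so the snippet compiles without HIP headers:

```cpp
#include <cstdint>

struct Dim3 { unsigned x, y, z; }; // stand-in for dim3, illustration only

constexpr std::int64_t integer_divide_ceil(std::int64_t a, std::int64_t b)
{
    return (a + b - 1) / b;
}

constexpr Dim3 grid_size(int batch, int nhead, int seqlen_k, int kN0)
{
    return Dim3{static_cast<unsigned>(integer_divide_ceil(seqlen_k, kN0)),
                static_cast<unsigned>(nhead),
                static_cast<unsigned>(batch)};
}

// e.g. 1000 keys with 128-key tiles need 8 blocks along x
static_assert(grid_size(2, 8, 1000, 128).x == 8);
```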
```diff
@@ -547,6 +597,7 @@ struct FmhaBwdDQDKDVKernel
         long_index_t batch_offset_randval = 0;
         long_index_t batch_offset_do      = 0;
         long_index_t batch_offset_lsed    = 0;
+        long_index_t batch_offset_dq_acc  = 0;
         long_index_t batch_offset_dk      = 0;
         long_index_t batch_offset_dv      = 0;
         long_index_t batch_offset_dbias   = 0;

@@ -561,7 +612,8 @@ struct FmhaBwdDQDKDVKernel
             batch_offset_k  = key_start * kargs.stride_k;
             batch_offset_v  = key_start * kargs.stride_v;
             batch_offset_do = query_start * kargs.stride_do;
-            batch_offset_lsed   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
+            batch_offset_lsed   = query_start;
+            batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
             batch_offset_dk = key_start * kargs.stride_dk;
             batch_offset_dv = key_start * kargs.stride_dv;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)

@@ -576,7 +628,7 @@ struct FmhaBwdDQDKDVKernel
             {
                 batch_offset_dbias = key_start;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval = query_start * kargs.stride_randval;
             }

@@ -608,6 +660,7 @@ struct FmhaBwdDQDKDVKernel
             batch_offset_v      = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
             batch_offset_do     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
             batch_offset_lsed   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
+            batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
             batch_offset_dk     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
             batch_offset_dv     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)

@@ -618,7 +671,7 @@ struct FmhaBwdDQDKDVKernel
             {
                 batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval =
                     static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
```
```diff
@@ -646,14 +699,11 @@ struct FmhaBwdDQDKDVKernel
         const OGradDataType* do_ptr =
             reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
             static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do + batch_offset_do;
-        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
-                                batch_offset_q;
         KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
+                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
                                 batch_offset_dk;
         VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
+                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
                                 batch_offset_dv;

         // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
```
```diff
@@ -663,45 +713,10 @@ struct FmhaBwdDQDKDVKernel
                 make_tuple(kargs.stride_q, 1),
                 number<FmhaPipeline::kAlignmentQ>{},
                 number<1>{});
-        const auto q_dram = [&]() {
-            if constexpr(FmhaPipeline::kQLoadOnce)
-            {
-                return pad_tensor_view(
-                    q_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    q_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-            }
-        }();
-        const auto qt_dram_naive = transform_tensor_view(
-            q_dram_naive,
-            make_tuple(make_pass_through_transform(kargs.hdim_q),
-                       make_pass_through_transform(kargs.seqlen_q)),
-            make_tuple(sequence<1>{}, sequence<0>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto qt_dram = [&]() {
-            if constexpr(FmhaPipeline::kQTLoadOnce)
-            {
-                return pad_tensor_view(
-                    qt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kM0>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    qt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK3>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
-            }
-        }();
+        const auto q_dram = pad_tensor_view(
+            q_dram_naive,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            sequence<kPadSeqLenQ, kPadHeadDimQ>{});

         const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
             k_ptr,

@@ -709,45 +724,10 @@ struct FmhaBwdDQDKDVKernel
                 make_tuple(kargs.stride_k, 1),
                 number<FmhaPipeline::kAlignmentK>{},
                 number<1>{});
-        const auto k_dram = [&]() {
-            if constexpr(FmhaPipeline::kKLoadOnce)
-            {
-                return pad_tensor_view(
-                    k_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    k_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
-            }
-        }();
-        const auto kt_dram_naive = transform_tensor_view(
-            k_dram_naive,
-            make_tuple(make_pass_through_transform(kargs.hdim_q),
-                       make_pass_through_transform(kargs.seqlen_k)),
-            make_tuple(sequence<1>{}, sequence<0>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto kt_dram = [&]() {
-            if constexpr(FmhaPipeline::kKTLoadOnce)
-            {
-                return pad_tensor_view(
-                    kt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    kt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
-            }
-        }();
+        const auto k_dram = pad_tensor_view(
+            k_dram_naive,
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            sequence<kPadSeqLenK, kPadHeadDimQ>{});

         const auto v_dram = [&]() {
             const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(

@@ -756,20 +736,10 @@ struct FmhaBwdDQDKDVKernel
                 make_tuple(kargs.stride_v, 1),
                 number<FmhaPipeline::kAlignmentV>{},
                 number<1>{});
-            if constexpr(FmhaPipeline::kVLoadOnce)
-            {
-                return pad_tensor_view(
-                    v_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimV>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    v_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimV>{});
-            }
+            return pad_tensor_view(
+                v_dram_naive,
+                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
+                sequence<kPadSeqLenK, kPadHeadDimV>{});
         }();

         const auto lse_dram = [&]() {
```
```diff
@@ -792,145 +762,89 @@ struct FmhaBwdDQDKDVKernel
                 make_tuple(kargs.stride_do, 1),
                 number<FmhaPipeline::kAlignmentOGrad>{},
                 number<1>{});
-        const auto do_dram = [&]() {
-            if constexpr(FmhaPipeline::kOGradLoadOnce)
-            {
-                return pad_tensor_view(
-                    do_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    do_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
-            }
-        }();
-        const auto dot_dram_naive = transform_tensor_view(
-            do_dram_naive,
-            make_tuple(make_pass_through_transform(kargs.hdim_v),
-                       make_pass_through_transform(kargs.seqlen_q)),
-            make_tuple(sequence<1>{}, sequence<0>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto dot_dram = [&]() {
-            if constexpr(FmhaPipeline::kOGradTLoadOnce)
-            {
-                return pad_tensor_view(
-                    dot_dram_naive,
-                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    dot_dram_naive,
-                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
-            }
-        }();
-        auto dq_dram = [&]() {
-            const auto dq_dram_naive =
-                make_naive_tensor_view<address_space_enum::global,
-                                       memory_operation_enum::atomic_add>(
-                    dq_ptr,
-                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
-                    make_tuple(kargs.stride_q, 1),
-                    number<FmhaPipeline::kAlignmentQGrad>{},
-                    number<1>{});
-            return pad_tensor_view(
-                dq_dram_naive,
-                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-        }();
+        const auto do_dram = pad_tensor_view(
+            do_dram_naive,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
+            sequence<kPadSeqLenQ, kPadHeadDimV>{});

-        auto q_dram_window = make_tile_window(
-            q_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kQLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kM0>{},
-                                      number<FmhaPipeline::kQKHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
-            }(),
-            {0, 0});
-        auto qt_dram_window = make_tile_window(
-            qt_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kQTLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                      number<FmhaPipeline::kM0>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                      number<FmhaPipeline::kK3>{});
-            }(),
-            {0, 0});
+        auto q_dram_window = make_tile_window(
+            q_dram,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            {0, 0});
-        auto k_dram_window = make_tile_window(
-            k_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kKLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kN0>{},
-                                      number<FmhaPipeline::kQKHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
-            }(),
-            {i_n0, 0});
-        auto kt_dram_window = make_tile_window(
-            kt_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kKTLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                      number<FmhaPipeline::kN0>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                      number<FmhaPipeline::kK4>{});
-            }(),
-            {0, i_n0});
+        auto k_dram_window = make_tile_window(
+            k_dram,
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            {i_n0, 0});
-        auto v_dram_window = make_tile_window(
-            v_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kVLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kN0>{},
-                                      number<FmhaPipeline::kVHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{});
-            }(),
-            {i_n0, 0});
+        auto v_dram_window = make_tile_window(
+            v_dram,
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
+            {i_n0, 0});
-        auto do_dram_window = make_tile_window(
-            do_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kOGradLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kM0>{},
-                                      number<FmhaPipeline::kVHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{});
-            }(),
-            {0, 0});
-        auto dot_dram_window = make_tile_window(
-            dot_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kOGradTLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kVHeaddim>{},
-                                      number<FmhaPipeline::kM0>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kVHeaddim>{},
-                                      number<FmhaPipeline::kK1>{});
-            }(),
-            {0, 0});
-        auto dq_dram_window = make_tile_window(
-            dq_dram,
-            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-            {0, 0});
+        auto do_dram_window = make_tile_window(
+            do_dram,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
+            {0, 0});
+        auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
+            if constexpr(kIsDeterministic)
+            {
+                AccDataType* dq_acc_ptr =
+                    reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
+                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
+                    static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
+                    batch_offset_dq_acc;
+                auto dq_acc_dram = [&]() {
+                    const auto dq_acc_dram_naive =
+                        make_naive_tensor_view<address_space_enum::global>(
+                            dq_acc_ptr,
+                            make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                            make_tuple(kargs.stride_dq_acc, 1),
+                            number<FmhaPipeline::kAlignmentQGrad>{},
+                            number<1>{});
+                    return pad_tensor_view(
+                        dq_acc_dram_naive,
+                        make_tuple(number<FmhaPipeline::kM0>{},
+                                   number<FmhaPipeline::kQKHeaddim>{}),
+                        sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                }();
+                return make_tile_window(
+                    dq_acc_dram,
+                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                    {0, 0});
+            }
+            else
+            {
+                AccDataType* dq_acc_ptr =
+                    reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
+                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
+                    batch_offset_dq_acc;
+                auto dq_acc_dram = [&]() {
+                    const auto dq_acc_dram_naive =
+                        make_naive_tensor_view<address_space_enum::global,
+                                               memory_operation_enum::atomic_add>(
+                            dq_acc_ptr,
+                            make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                            make_tuple(kargs.stride_dq_acc, 1),
+                            number<FmhaPipeline::kAlignmentQGrad>{},
+                            number<1>{});
+                    return pad_tensor_view(
+                        dq_acc_dram_naive,
+                        make_tuple(number<FmhaPipeline::kM0>{},
+                                   number<FmhaPipeline::kQKHeaddim>{}),
+                        sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                }();
+                return make_tile_window(
+                    dq_acc_dram,
+                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                    {0, 0});
+            }
+        }();

         auto lse_dram_window =
             make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
```
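dQ is no longer written directly: the kernel accumulates an fp32 dq_acc buffer instead. In the default path every key-tile block atomically adds its partial dQ into one buffer; in the deterministic path each key-tile block writes an exclusive split slice (offset by i_tile_n * split_stride_dq_acc) and the separate convert kernel later reduces the splits in a fixed order, which pins down the floating-point addition order. A CPU analogue of the two strategies, illustrative only:

```cpp
#include <atomic>
#include <cstddef>
#include <vector>

// Atomic path: blocks add into one shared buffer in nondeterministic order.
void accumulate_atomic(std::atomic<float>* dq_acc, std::size_t row, float partial)
{
    float cur = dq_acc[row].load();
    while(!dq_acc[row].compare_exchange_weak(cur, cur + partial)) {} // CAS retry loop
}

// Deterministic path: block i_tile_n owns its own slice; a follow-up pass
// reduces the nsplits slices in a fixed order (see BlockFmhaBwdConvertQGrad).
void accumulate_split(std::vector<float>& dq_acc_splits, // nsplits * seqlen values
                      std::size_t seqlen, std::size_t i_tile_n, std::size_t row,
                      float partial)
{
    dq_acc_splits[i_tile_n * seqlen + row] += partial; // exclusive, race-free slice
}
```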
```diff
@@ -1008,9 +922,7 @@ struct FmhaBwdDQDKDVKernel
                     // TODO: how to use s_read?
                     AccDataType slope =
                         *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
                           i_batch_ * kargs.alibi_slope_stride + i_nhead_);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                     slope *= ck_tile::log2e_v<>;
 #endif
                     if constexpr(kHasMask)
                     {
                         return make_alibi_from_lr_mask<AccDataType, false>(slope,

@@ -1035,33 +947,32 @@ struct FmhaBwdDQDKDVKernel
         // dropout
-        float rp_undrop             = 1;
-        float scale_rp_undrop       = 1;
-        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        uint64_t drop_seed          = 0;
-        uint64_t drop_offset        = 0;
-        bool is_store_randval       = false;
-
-        if constexpr(kHasDropout)
-        {
-            rp_undrop           = kargs.rp_undrop;
-            scale_rp_undrop     = kargs.scale_rp_undrop;
-            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
-            drop_seed           = kargs.drop_seed;
-            drop_offset         = kargs.drop_offset;
-            is_store_randval    = kargs.is_store_randval;
-        }
-        BlockDropout dropout(i_batch,
-                             i_nhead,
-                             kargs.num_head_q,
-                             drop_seed,
-                             drop_offset,
-                             rp_undrop,
-                             p_undrop_in_uint8_t,
-                             is_store_randval);
+        auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
+            if constexpr(kHasDropout)
+            {
+                return FmhaDropout{i_batch_,
+                                   i_nhead_,
+                                   kargs.num_head_q,
+                                   kargs.drop_seed,
+                                   kargs.drop_offset,
+                                   kargs.rp_undrop,
+                                   kargs.p_undrop_in_uint8_t};
+            }
+            else
+            {
+                return FmhaDropout{};
+            };
+        }();

         auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
             constexpr auto randval_dram_window_lengths =
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 RandValOutputDataType* rand_val_ptr =
                     reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
```
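Dropout state is now built by an immediately-invoked lambda with if constexpr: when dropout is compiled in, a fully parameterized FmhaDropout is constructed from the kargs; otherwise a default-constructed one, so the disabled configuration never touches dropout arguments. A standalone sketch of the construction pattern, with a simplified stand-in type:

```cpp
// Both branches return the same type (FmhaDropout in the kernel); the branch
// not taken is discarded at compile time.
struct Dropout
{
    unsigned long long seed = 0, offset = 0;
    bool enabled = false;
};

template <bool kHasDropout>
Dropout make_dropout(unsigned long long seed, unsigned long long offset)
{
    return [&]() {
        if constexpr(kHasDropout)
        {
            return Dropout{seed, offset, true};
        }
        else
        {
            return Dropout{}; // no runtime state when dropout is compiled out
        }
    }();
}
```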
```diff
@@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel
         }();

         auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
-                                                         qt_dram_window,
                                                          k_dram_window,
-                                                         kt_dram_window,
                                                          v_dram_window,
                                                          bias_dram_window,
                                                          randval_dram_window,
                                                          do_dram_window,
-                                                         dot_dram_window,
                                                          lse_dram_window,
                                                          d_dram_window,
                                                          dq_dram_window,

@@ -1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel
                                                          mask,
                                                          position_encoding,
                                                          kargs.raw_scale,
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                                                          kargs.scale,
 #endif
-                                                         rp_undrop,
-                                                         scale_rp_undrop,
                                                          smem_ptr,
```
```diff
@@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel
     }
 };

-template <typename TilePartitioner_, typename FmhaBwdOGradDotO_>
+template <typename FmhaBwdOGradDotO_>
 struct FmhaBwdOGradDotOKernel
 {
-    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;

     static constexpr ck_tile::index_t kBlockSize  = FmhaBwdOGradDotO::kBlockSize;
     static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;

@@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
         ck_tile::index_t nhead_stride_do;
         ck_tile::index_t nhead_stride_o;
         ck_tile::index_t nhead_stride_d;
-        ck_tile::index_t batch_stride_d;
     };

     struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
     {
         ck_tile::index_t batch_stride_do;
         ck_tile::index_t batch_stride_o;
+        ck_tile::index_t batch_stride_d;
     };

     struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs

@@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
                      stride_o,
                      nhead_stride_do,
                      nhead_stride_o,
-                     nhead_stride_d,
-                     batch_stride_d},
+                     nhead_stride_d},
                     batch_stride_do,
-                    batch_stride_o};
+                    batch_stride_o,
+                    batch_stride_d};

         return kargs;
     }

@@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
               ck_tile::index_t stride_o,
               ck_tile::index_t nhead_stride_do,
               ck_tile::index_t nhead_stride_o,
-              ck_tile::index_t nhead_stride_d,
-              ck_tile::index_t batch_stride_d)
+              ck_tile::index_t nhead_stride_d)
     {
         Kargs kargs{{o_ptr,
                      do_ptr,

@@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
                      stride_o,
                      nhead_stride_do,
                      nhead_stride_o,
-                     nhead_stride_d,
-                     batch_stride_d},
+                     nhead_stride_d},
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr)};

         return kargs;

@@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel
     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                 ck_tile::index_t nhead_,
                                                 ck_tile::index_t seqlen_q_)
     {
-        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
+        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
     }

+    CK_TILE_DEVICE static constexpr auto GetTileIndex()
+    {
+        const index_t i_block = blockIdx.x;
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
+    }
+
     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

@@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
         // divide problem
-        const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);
+        const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();

         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

@@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
             batch_offset_o  = query_start * kargs.stride_o;
             batch_offset_do = query_start * kargs.stride_do;
-            batch_offset_d  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
+            batch_offset_d  = query_start;

             // get real # queries & # keys under group mode
             const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
```
```diff
@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel
     }
 };
```

The remainder of this hunk appends the new FmhaBwdConvertQGradKernel:

```cpp
template <typename FmhaBwdConvertQGrad_>
struct FmhaBwdConvertQGradKernel
{
    using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;

    static constexpr ck_tile::index_t kBlockSize  = FmhaBwdConvertQGrad::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
    static constexpr ck_tile::index_t kM0         = FmhaBwdConvertQGrad::kM0;
    static constexpr ck_tile::index_t kN0         = FmhaBwdConvertQGrad::kN0;
    static constexpr ck_tile::index_t kQKHeaddim  = FmhaBwdConvertQGrad::kQKHeaddim;

    using AccDataType   = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
    using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;

    static constexpr bool kIsGroupMode     = FmhaBwdConvertQGrad::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = FmhaBwdConvertQGrad::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = FmhaBwdConvertQGrad::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    // clang-format on

    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadHeadDimQ) n += "d";
            return n.empty() ? n : std::string("p") + n;
        }();
        return
            _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" +
            _SS_(t2s<QGradDataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") +
            (kIsDeterministic ? "_deterministic" : "") +
            "_" + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    // to avoid the duplicated base class problem, introduce a template arg
    template <ck_tile::index_t I>
    struct FmhaBwdConvertQGradEmptyKargs
    {
    };

    // kargs use aggregate initializer, so no constructor is provided
    // use inheritance to minimize karg size
    // users need to use the MakeKargs() function to create kargs.
    struct FmhaBwdConvertQGradCommonKargs
    {
        const void* dq_acc_ptr;
        void* dq_ptr;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;

        ck_tile::index_t stride_dq;
        ck_tile::index_t stride_dq_acc;
        ck_tile::index_t nhead_stride_dq;
        ck_tile::index_t nhead_stride_dq_acc;
    };

    struct FmhaBwdConvertQGradDeterministicKargs
    {
        ck_tile::index_t split_stride_dq_acc = 0;
    };

    struct FmhaBwdConvertQGradBatchModeKargs
        : FmhaBwdConvertQGradCommonKargs,
          std::conditional_t<kIsDeterministic,
                             FmhaBwdConvertQGradDeterministicKargs,
                             FmhaBwdConvertQGradEmptyKargs<0>>
    {
        ck_tile::index_t batch_stride_dq;
        ck_tile::index_t batch_stride_dq_acc;
    };

    struct FmhaBwdConvertQGradGroupModeKargs
        : FmhaBwdConvertQGradCommonKargs,
          std::conditional_t<kIsDeterministic,
                             FmhaBwdConvertQGradDeterministicKargs,
                             FmhaBwdConvertQGradEmptyKargs<0>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
    };

    using Kargs = std::conditional_t<kIsGroupMode,
                                     FmhaBwdConvertQGradGroupModeKargs,
                                     FmhaBwdConvertQGradBatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t batch_stride_dq,
              ck_tile::index_t batch_stride_dq_acc,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr,
                     dq_ptr,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
                     stride_dq,
                     stride_dq_acc,
                     nhead_stride_dq,
                     nhead_stride_dq_acc},
                    {},
                    batch_stride_dq,
                    batch_stride_dq_acc};

        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }

        return kargs;
    }

    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr,
                     dq_ptr,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     stride_dq,
                     stride_dq_acc,
                     nhead_stride_dq,
                     nhead_stride_dq_acc},
                    {},
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr)};

        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }

        return kargs;
    }

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_)
    {
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE static constexpr auto GetTileIndex()
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide problem
        const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

        long_index_t batch_offset_dq     = 0;
        long_index_t batch_offset_dq_acc = 0;
        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            batch_offset_dq     = query_start * kargs.stride_dq;
            batch_offset_dq_acc = query_start * kargs.stride_dq_acc;

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
            if constexpr(kIsDeterministic)
            {
                const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
            }

            // # of required blocks is different in each groups, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_dq     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
            batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
        }

        // for simplicity, batch stride we just modify the pointer
        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
                                batch_offset_dq;

        // dQAcc/dQ DRAM and DRAM window
        const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
            if constexpr(kIsDeterministic)
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
                    batch_offset_dq_acc;

                const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});
                return pad_tensor_view(
                    dq_acc_dram_naive,
                    make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                    sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
                    batch_offset_dq_acc;

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.stride_dq_acc, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});
                return pad_tensor_view(dq_acc_dram_naive,
                                       make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                       sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();

        auto dq_dram = [&]() {
            auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dq_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_dq, 1),
                number<FmhaBwdConvertQGrad::kAlignmentQGrad>{},
                number<1>{});
            return pad_tensor_view(dq_dram_naive,
                                   make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();

        auto dq_acc_dram_window = [&]() {
            if constexpr(kIsDeterministic)
            {
                return make_tile_window(
                    dq_acc_dram,
                    make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                    {0, i_m0, 0});
            }
            else
            {
                return make_tile_window(
                    dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
            }
        }();

        auto dq_dram_window =
            make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});

        if constexpr(kIsDeterministic)
        {
            const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
        }
        else
        {
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
        }
    }
};

} // namespace ck_tile
```
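Taken together, this file now defines a three-stage backward pass: dot(dO, O), then dQ/dK/dV, then the dQ convert/reduce. A hedged host-side sketch of the ordering with printf stand-ins for the launches (only the sequence and the fp32 dq_acc hand-off mirror the diff):

```cpp
#include <cstdio>

void run_dot_do_o()   { std::puts("1. FmhaBwdOGradDotOKernel: D = rowsum(dO * O)"); }
void run_dq_dk_dv()   { std::puts("2. FmhaBwdDQDKDVKernel: dK/dV stored, dq_acc accumulated in fp32"); }
void run_convert_dq() { std::puts("3. FmhaBwdConvertQGradKernel: reduce splits, cast dq_acc -> dQ"); }

int main()
{
    run_dot_do_o();
    run_dq_dk_dv();   // grid.x tiles over seqlen_k (one kN0 key tile per block)
    run_convert_dq(); // grid.x tiles over seqlen_q (one kM0 query tile per block)
    return 0;
}
```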
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp (deleted, 100644 → 0; view file @ 408534d4)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

} // namespace ck_tile
```
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp (view file @ f84e2020)

```diff
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
               ck_tile::index_t nhead_stride_randval,
               ck_tile::index_t nhead_stride_lse,
               ck_tile::index_t nhead_stride_o,
-              ck_tile::index_t batch_stride_lse,
               ck_tile::index_t window_size_left,
               ck_tile::index_t window_size_right,
               ck_tile::index_t mask_type,

@@ -448,7 +447,6 @@ struct FmhaFwdKernel
         {
             kargs.lse_ptr          = lse_ptr;
             kargs.nhead_stride_lse = nhead_stride_lse;
-            kargs.batch_stride_lse = batch_stride_lse;
         }
         if constexpr(kDoFp8StaticQuant)
         {

@@ -524,7 +522,7 @@ struct FmhaFwdKernel
             }
             if constexpr(kStoreLSE)
             {
-                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
+                batch_offset_lse = query_start;
             }
             if constexpr(kHasDropout)
             {
```
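In group mode the forward kernel now addresses LSE as a packed per-token array rather than via a batch stride. A sketch of the addressing change (function name local to the sketch):

```cpp
#include <cstdint>

// LSE stored packed over tokens, one value per (head, token): a batch's slice
// begins at its seqstart offset instead of i_batch * batch_stride_lse.
std::int64_t lse_offset(const std::int32_t* seqstart_q, int i_batch, int i_nhead,
                        std::int64_t nhead_stride_lse)
{
    const std::int64_t query_start = seqstart_q[i_batch];
    return static_cast<std::int64_t>(i_nhead) * nhead_stride_lse + query_start;
}
```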
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp (view file @ f84e2020)

```diff
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
         ck_tile::index_t nhead_stride_o_acc;
         ck_tile::index_t nhead_stride_o;
-        ck_tile::index_t batch_stride_lse_acc;
         ck_tile::index_t batch_stride_o_acc;
         ck_tile::index_t split_stride_lse_acc;

@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
           std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
     {
         ck_tile::index_t batch_stride_o;
+        ck_tile::index_t batch_stride_lse_acc;
     };

     struct GroupModeKargs

@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
                      nhead_stride_lse_acc,
                      nhead_stride_o_acc,
                      nhead_stride_o,
-                     batch_stride_lse_acc,
                      batch_stride_o_acc,
                      split_stride_lse_acc,
                      split_stride_o_acc}, // args for common karg
                     {},                  // placeholder for lse
                     {},                  // placeholder for fp8_static_quant args
-                    batch_stride_o};
+                    batch_stride_o,
+                    batch_stride_lse_acc};

         if constexpr(kStoreLSE)
         {

@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
               ck_tile::index_t nhead_stride_o_acc,
               ck_tile::index_t nhead_stride_lse,
               ck_tile::index_t nhead_stride_o,
-              ck_tile::index_t batch_stride_lse_acc,
               ck_tile::index_t batch_stride_o_acc,
-              ck_tile::index_t batch_stride_lse,
               ck_tile::index_t split_stride_lse_acc,
               ck_tile::index_t split_stride_o_acc)
     {

@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
                      nhead_stride_lse_acc,
                      nhead_stride_o_acc,
                      nhead_stride_o,
-                     batch_stride_lse_acc,
                      batch_stride_o_acc,
                      split_stride_lse_acc,
                      split_stride_o_acc}, // args for common karg

@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
         {
             kargs.lse_ptr          = lse_ptr;
             kargs.nhead_stride_lse = nhead_stride_lse;
-            kargs.batch_stride_lse = batch_stride_lse;
         }
         if constexpr(kDoFp8StaticQuant)
         {

@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

-        const long_index_t batch_offset_lse_acc =
-            static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
         const long_index_t batch_offset_o_acc =
             static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
+        long_index_t batch_offset_lse_acc = 0;
         long_index_t batch_offset_lse     = 0;
         long_index_t batch_offset_o       = 0;

-        if constexpr(kStoreLSE)
-        {
-            batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
-        }
         if constexpr(kIsGroupMode)
         {
             // get starting offset for each batch
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
             batch_offset_o       = query_start * kargs.row_stride_o;
+            batch_offset_lse_acc = query_start;
+            if constexpr(kStoreLSE)
+            {
+                batch_offset_lse = query_start;
+            }

             // get real # queries & # keys under group mode
             const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;

@@ -307,6 +304,12 @@ struct FmhaFwdSplitKVCombineKernel
         else
         {
             batch_offset_o       = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
+            batch_offset_lse_acc =
+                static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
+            if constexpr(kStoreLSE)
+            {
+                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
+            }
         }

         // for simplicity, batch stride we just modify the pointer
```
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp (view file @ f84e2020)

```diff
@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t nhead_stride_lse_acc;
         ck_tile::index_t nhead_stride_o_acc;
-        ck_tile::index_t batch_stride_lse_acc;
         ck_tile::index_t batch_stride_o_acc;
         ck_tile::index_t split_stride_lse_acc;

@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
+        ck_tile::index_t batch_stride_lse_acc;
     };

     struct GroupModeKargs

@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
                      nhead_stride_v,
                      nhead_stride_lse_acc,
                      nhead_stride_o_acc,
-                     batch_stride_lse_acc,
                      batch_stride_o_acc,
                      split_stride_lse_acc,
                      split_stride_o_acc}, // args for common karg

@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
                     {}, // placeholder for dropout
                     batch_stride_q,
                     batch_stride_k,
-                    batch_stride_v};
+                    batch_stride_v,
+                    batch_stride_lse_acc};

         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {

@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
               ck_tile::index_t nhead_stride_randval,
               ck_tile::index_t nhead_stride_lse_acc,
               ck_tile::index_t nhead_stride_o_acc,
-              ck_tile::index_t batch_stride_lse_acc,
               ck_tile::index_t batch_stride_o_acc,
               ck_tile::index_t split_stride_lse_acc,
               ck_tile::index_t split_stride_o_acc,

@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
                      nhead_stride_v,
                      nhead_stride_lse_acc,
                      nhead_stride_o_acc,
-                     batch_stride_lse_acc,
                      batch_stride_o_acc,
                      split_stride_lse_acc,
                      split_stride_o_acc}, // args for common karg

@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
         long_index_t batch_offset_v       = 0;
         long_index_t batch_offset_bias    = 0;
         long_index_t batch_offset_randval = 0;
-        const long_index_t batch_offset_lse_acc =
-            static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
+        long_index_t batch_offset_lse_acc = 0;
         const long_index_t batch_offset_o_acc =
             static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;

@@ -524,6 +521,7 @@ struct FmhaFwdSplitKVKernel
             batch_offset_q = query_start * kargs.stride_q;
             batch_offset_k = key_start * kargs.stride_k;
+            batch_offset_lse_acc = query_start;
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
                 batch_offset_v = key_start * kargs.stride_v;

@@ -567,6 +565,7 @@ struct FmhaFwdSplitKVKernel
             batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
             batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
             batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
+            batch_offset_lse_acc =
+                static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
```
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp  0 → 100644  View file @ f84e2020

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;

    static constexpr index_t kM0         = Problem::kM0;
    static constexpr index_t kN0         = Problem::kN0;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kQKHeaddim  = Problem::kQKHeaddim;

    static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    static constexpr index_t kAlignmentQGradAcc =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    // Convert only
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
                                        QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
    {
        static_assert(
            std::is_same_v<AccDataType, remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType, remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");

        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradDramTileDistribution<Problem>());

        auto dq_acc   = load_tile(dq_acc_dram_window);
        const auto dq = cast_tile<QGradDataType>(dq_acc);
        store_tile(dq_dram_block_window_tmp, dq);
    }

    // Reduce + Convert
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
                                        QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
                                        index_t nsplits) const
    {
        static_assert(
            std::is_same_v<AccDataType, remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType, remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");

        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());

        auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
        clear_tile(dq_acc);
        constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();

        index_t i_total_loops = 0;
        auto dq_acc_buf = load_tile(dq_acc_dram_window);
        move_tile_window(dq_acc_dram_window, {1, 0, 0});
        do
        {
            sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                    sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                        constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                        dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                    });
                });
            });
            dq_acc_buf = load_tile(dq_acc_dram_window);
            move_tile_window(dq_acc_dram_window, {1, 0, 0});
            i_total_loops += 1;
        } while(i_total_loops < (nsplits - 1));
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                });
            });
        });

        // declare dq
        constexpr auto dq_converted_dstr = Policy::template MakePostQGradAccDramTileDistribution<Problem>();
        auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
                });
            });
        });
        constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
        auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
        dq.get_thread_buffer() = dq_converted.get_thread_buffer();
        store_tile(dq_dram_block_window_tmp, dq);
    }
};

} // namespace ck_tile
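A plausible call site for the two overloads above, assuming `dq_acc_window` is a 3-D (split, M0, headdim) DRAM window of `AccDataType` partials and `dq_window` the 2-D `QGradDataType` output window; all names here are illustrative, not taken from this commit:

    // Hedged usage sketch of BlockFmhaBwdConvertQGrad.
    using ConvertQGrad = ck_tile::BlockFmhaBwdConvertQGrad<Problem>;
    ConvertQGrad{}(dq_acc_window, dq_window);                 // convert-only: single accumulator buffer
    ConvertQGrad{}(dq_acc_window, dq_window, /*nsplits=*/8);  // reduce nsplits partial dQ buffers, then convert

The reduce path keeps a running `dq_acc` tile in registers and walks the split dimension one buffer ahead, so each `load_tile` overlaps with the accumulation of the previous split.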
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp  View file @ f84e2020
...
...
@@ -4,11 +4,11 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

 namespace ck_tile {

-template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
+template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
 struct BlockFmhaBwdOGradDotO
 {
     using ODataType = remove_cvref_t<typename Problem::ODataType>;
...
...
@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
     static constexpr index_t kAlignmentO =
         kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
     static constexpr index_t kAlignmentOGrad =
-        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
+        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();

     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
...
...
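For orientation: `BlockFmhaBwdOGradDotO` produces the per-row term D = rowsum(O ∘ dO) that the dq/dk/dv pipelines later consume in dS = P ∘ (dP − D). A scalar host-side reference of that reduction (a sketch for clarity, not the tile kernel above):

    // Reference semantics of dot(dO, O): D[m] = sum_k O[m][k] * dO[m][k].
    void dot_do_o_reference(const float* o, const float* do_grad, float* d, int seqlen_q, int hdim_v)
    {
        for(int m = 0; m < seqlen_q; ++m)
        {
            float acc = 0.f;
            for(int k = 0; k < hdim_v; ++k)
                acc += o[m * hdim_v + k] * do_grad[m * hdim_v + k]; // elementwise product, reduced over head dim
            d[m] = acc; // consumed later as dS = P * (dP - D)
        }
    }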
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp  deleted 100644 → 0  View file @ 408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// These templates are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ false,
                                      /* KTLoadOnce_ = */ false,
                                      /* VLoadOnce_ = */ false,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp → include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp  View file @ f84e2020
...
...
@@ -6,13 +6,13 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"

 namespace ck_tile {

-template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
-struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
+template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
+struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
 {
     using QDataType = remove_cvref_t<typename Problem::QDataType>;
     using KDataType = remove_cvref_t<typename Problem::KDataType>;
...
...
@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
     using VGradDataType    = remove_cvref_t<typename Problem::VGradDataType>;
     using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
     using FmhaMask         = remove_cvref_t<typename Problem::FmhaMask>;
+    using FmhaDropout      = remove_cvref_t<typename Problem::FmhaDropout>;
+    using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;

     using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
...
...
@@ -46,14 +48,6 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
     static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
     static constexpr index_t kVHeaddim  = BlockFmhaShape::kVHeaddim;

-    static constexpr bool kQLoadOnce      = true;
-    static constexpr bool kQTLoadOnce     = false;
-    static constexpr bool kKLoadOnce      = true;
-    static constexpr bool kKTLoadOnce     = false;
-    static constexpr bool kVLoadOnce      = true;
-    static constexpr bool kOGradLoadOnce  = true;
-    static constexpr bool kOGradTLoadOnce = false;
-
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
     static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
...
...
@@ -61,7 +55,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
-    static constexpr bool kHasDropout  = Problem::kHasDropout;
+    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

     // last dimension vector length used to create tensor view(and decide buffer_load vector length)
     // ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
kPadHeadDimQ
?
2
:
Policy
::
template
GetAlignmentQGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
1
;
static
constexpr
index_t
kAlignmentKGrad
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentKGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentVGrad
=
...
...
@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
     static constexpr index_t kAlignmentBias =
         kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();

-    static constexpr const char* name = "qs_ks_vr_dos";
+    static constexpr const char* name = "kr_ktr_vr";

     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
...
...
@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
     }

     template <typename QDramBlockWindowTmp,
-              typename QTDramBlockWindowTmp,
               typename KDramBlockWindowTmp,
-              typename KTDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
               typename RandValDramBlockWindowTmp,
               typename OGradDramBlockWindowTmp,
-              typename OGradTDramBlockWindowTmp,
               typename LSEDramBlockWindowTmp,
               typename DDramBlockWindowTmp,
               typename QGradDramBlockWindowTmp,
...
...
@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
-               const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
                const KDramBlockWindowTmp& k_dram_block_window_tmp,
-               const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
                const VDramBlockWindowTmp& v_dram_block_window_tmp,
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
                const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
-               const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
                const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
                const DDramBlockWindowTmp& d_dram_block_window_tmp,
                const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
...
...
@@ -122,13 +107,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
                FmhaMask mask,
                PositionEncoding position_encoding,
                float raw_scale,
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                float scale,
 #endif
                float rp_undrop,
                float scale_rp_undrop,
                void* smem_ptr,
-               BlockDropout& dropout) const
+               FmhaDropout& dropout) const
     {
         static_assert(
             std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...
...
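The pipeline now receives the dropout object as the Problem-provided `FmhaDropout` type, whose compile-time `IsDropout` flag replaces the old `kHasDropout` constant used below. A hedged sketch of the shape such a policy type could take (the real ck_tile dropout types carry more state, e.g. Philox seeds and offsets; these two structs are illustrative only):

    // Illustrative dropout policy types; branches on IsDropout compile away entirely.
    struct NoDropout   { static constexpr bool IsDropout = false; };
    struct WithDropout { static constexpr bool IsDropout = true; /* rng state ... */ };

    // Inside the pipeline:
    //   if constexpr(FmhaDropout::IsDropout) { dropout.template Run<...>(...); }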
@@ -138,9 +121,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
                 remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
                 std::is_same_v<LSEDataType, remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
-                std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
-                std::is_same_v<QGradDataType,
-                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
+                std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
             "wrong!");

         static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
...
...
@@ -156,77 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
                       kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                       "wrong!");

-        // Q tile in LDS
-        QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
-            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
-        auto q_lds = make_tensor_view<address_space_enum::lds>(
-            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
-        auto q_lds_window = make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
-        // QT tile in LDS
-        auto qt_lds = make_tensor_view<address_space_enum::lds>(
-            q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
-        auto qt_lds_window = make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
-        // K tile in LDS
-        auto k_lds = make_tensor_view<address_space_enum::lds>(
-            reinterpret_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
-        auto k_lds_window = make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
-        // KT tile in LDS
-        auto kt_lds = make_tensor_view<address_space_enum::lds>(
-            reinterpret_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
-        auto kt_lds_window = make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
-        // OGrad tile in LDS
-        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
-            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-            Policy::template GetSmemSizeQ<Problem>()));
-        auto do_lds = make_tensor_view<address_space_enum::lds>(
-            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
-        auto do_lds_window = make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
-        // OGradT tile in LDS
-        auto dot_lds = make_tensor_view<address_space_enum::lds>(
-            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
-        auto dot_lds_window = make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
-        // SGrad tile in LDS
-        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
-            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-            Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeOGrad<Problem>()));
-        auto ds_lds = make_tensor_view<address_space_enum::lds>(
-            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
-        auto ds_lds_window = make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
-        // BiasT/BiasGradT tile in LDS, use the same size and layout
-        BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
-            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-            Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeOGrad<Problem>()));
-        auto biast_lds = make_tensor_view<address_space_enum::lds>(
-            biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
-        auto biast_lds_shuffle_window =
-            make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
-        auto dbiast_lds_shuffle_window =
-            make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0},
-                             Policy::template MakeShuffledBiasTileDistribution<Problem>());
-
-        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
-                      "BiasDataType and BiasGradDataType should be the same!");
-
         // Block GEMM
         constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
         constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...
@@ -234,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();

-        auto v_dram_window =
-            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
-                             v_dram_block_window_tmp.get_window_lengths(),
-                             v_dram_block_window_tmp.get_window_origin(),
-                             Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
-        auto v = load_tile(v_dram_window); // persistent V register tile
-
-        using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
-        using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
-        using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());
-
-        // init VGrad & KGrad
         auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
         auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
-        clear_tile(dv_acc);
-        clear_tile(dk_acc);

+        // K, HBM ->LDS ->Reg
         auto k_dram_window =
             make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                              k_dram_block_window_tmp.get_window_lengths(),
                              k_dram_block_window_tmp.get_window_origin(),
-                             Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load
+                             Policy::template MakeKDramTileDistribution<Problem>());
+        __builtin_amdgcn_sched_barrier(0);
         const auto k_origin = k_dram_window.get_window_origin();
         // Early termination
         const auto [seqlen_q_start, seqlen_q_end] =
             mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
...
...
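With causal or windowed masking, an entire K/V tile may see no unmasked Q rows. The pattern above asks the mask for the valid Q range of this tile and, when it is empty, returns the already-zeroed dK/dV immediately. A condensed sketch of that early-exit logic, assuming `num_total_loop` is derived by ceil-division as in similar ck_tile pipelines:

    // Early-exit sketch (simplified; num_total_loop derivation is an assumption).
    const auto [seqlen_q_start, seqlen_q_end] =
        mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
    const index_t num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
    if(num_total_loop <= 0)
        return make_tuple(dk_acc, dv_acc); // dk_acc/dv_acc are still zero-cleared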
@@ -274,217 +169,408 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
}
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKRegSliceBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
VDataType
*
v_lds_ptr
=
static_cast
<
VDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto
shuffled_k_block_tile
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeShuffledKRegWriteBlockDescriptor
<
Problem
>());
KDataType
*
kt_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
shuffled_k_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
auto
kt_lds_read_window
=
make_tile_window
(
kt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeKTRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto
k_block_tile
=
load_tile
(
k_dram_window
);
auto
v_block_tile
=
load_tile
(
v_dram_window
);
store_tile
(
k_lds_write_window
,
k_block_tile
);
shuffle_tile
(
shuffled_k_block_tile
,
k_block_tile
);
store_tile
(
shuffled_k_lds_write_window
,
shuffled_k_block_tile
);
block_sync_lds
();
k_reg_tensor
=
load_tile
(
k_lds_read_window
);
block_sync_lds
();
auto
kt_reg_tensor
=
load_tile
(
kt_lds_read_window
);
store_tile
(
k
_lds_window
,
k
_block_tile
);
// // persistent K in LDS
store_tile
(
v
_lds_
write_
window
,
v
_block_tile
);
auto
q_dram_block_window
=
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
block_sync_lds
();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
auto
do_dram_block_window
=
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
q_lds_window
.
get_window_origin
(),
Policy
::
template
MakeQRegSliceBlockDescriptor
<
Problem
>());
auto
pt_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakePTRegSliceBlockDescriptor
<
Problem
>());
// QT: Reg -> Reg-> LDS
auto
shuffled_q_block_tile
=
make_static_distributed_tensor
<
QDataType
>
(
Policy
::
template
MakeShuffledQRegWriteBlockDescriptor
<
Problem
>());
QDataType
*
qt_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
shuffled_q_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
auto
qt_lds_read_window
=
make_tile_window
(
qt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeQTRegSliceBlockDescriptor
<
Problem
>());
// dO: HBM ->Reg ->LDS
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window_tmp
.
get_bottom_tensor_view
(),
do_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
auto
dq_dram_block_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()));
auto
lse_dram_block_window
=
make_tile_window
(
lse_dram_block_window_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
d_dram_block_window
=
make_tile_window
(
d_dram_block_window_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
do_lds_window
.
get_window_origin
(),
Policy
::
template
MakeOGradRegSliceBlockDescriptor
<
Problem
>());
// dOT: Reg ->Reg ->LDS
auto
shuffled_do_block_tile
=
make_static_distributed_tensor
<
OGradDataType
>
(
Policy
::
template
MakeShuffledOGradRegWriteBlockDescriptor
<
Problem
>());
OGradDataType
*
dot_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
shuffled_do_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
auto
dot_lds_read_window
=
make_tile_window
(
dot_read_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeOGradTRegSliceBlockDescriptor
<
Problem
>());
// dS: Reg -> Reg -> LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
ds_lds_read_window
=
make_tile_window
(
ds_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK4
>
{}),
ds_lds_window
.
get_window_origin
(),
Policy
::
template
MakeSGradRegSliceBlockDescriptor
<
Problem
>());
auto
dst_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakeSGradTRegSliceBlockDescriptor
<
Problem
>());
// Bias: HBM ->Reg ->Reg ->LDS
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_block_window
=
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})});
// M/N
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})},
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_block_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
BiasDataType
*
bias_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
bias_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
bias_lds_ptr
,
Policy
::
template
MakeBiasLdsBlockDescriptor
<
Problem
>());
auto
bias_lds_write_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
bias_s_lds_read_window
=
make_tile_window
(
bias_lds_write_window
.
get_bottom_tensor_view
(),
bias_lds_write_window
.
get_window_lengths
(),
bias_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeBiasSTileDistribution
<
decltype
(
gemm_0
)>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// LSE: HBM -> LDS ->Reg
auto
lse_dram_window
=
make_tile_window
(
lse_dram_block_window
.
get_bottom_tensor_view
(),
lse_dram_block_window
.
get_window_lengths
(),
l
se
_dram_block_window
.
get_window_origin
()
,
lse_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window
_tmp
.
get_window_lengths
(),
{
se
qlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
LSEDataType
*
lse_lds_ptr
=
static_cast
<
LSEDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
lse_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
lse_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
lse_lds_write_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
lse_lds_read_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
},
Policy
::
template
MakeLSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
// D: HBM ->Reg
auto
d_dram_window
=
make_tile_window
(
d_dram_block_window
.
get_bottom_tensor_view
(),
d_dram_block_window
.
get_window_lengths
(),
d_dram_block_window
.
get_window_origin
()
,
d_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window
_tmp
.
get_window_lengths
(),
{
seqlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window
.
get_bottom_tensor_view
(),
bias_dram_block_window
.
get_window_lengths
(),
bias_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
DDataType
*
d_lds_ptr
=
static_cast
<
DDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()));
auto
d_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
d_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
biast_lds_window
=
make_tile_window
(
biast_lds_shuffle_window
.
get_bottom_tensor_view
(),
biast_lds_shuffle_window
.
get_window_lengths
(),
biast_lds_shuffle_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTTileDistribution
<
decltype
(
gemm_0
)>());
auto
d_lds_write_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>
(
auto
d_lds_read_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
},
Policy
::
template
MakeLSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
// RandVal: HBM ->Reg
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>(
randval_dram_block_window_tmp
,
seqlen_q_start
);
// BiasGrad
// Reg ->LDS ->Reg ->HBM
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
auto
dbias_lds_read_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
// ----------------------------Loop write out------------------------------//
auto
dq_dram_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
using
SPBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kM0
/
kK1
;
constexpr
index_t
k2_loops
=
kVHeaddim
/
kK2
;
constexpr
index_t
k3_loops
=
kM0
/
kK3
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
==
kK0
,
"kQKHeaddim should equal to kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
==
kK2
,
"kVHeaddim should equal to kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
do
{
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window
.
get_bottom_tensor_view
(),
q_dram_block_window
.
get_window_lengths
(),
q_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
// Q DRAM tile window for
// load
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window
.
get_bottom_tensor_view
(),
do_dram_block_window
.
get_window_lengths
(),
do_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
// OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{}
;
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
)
;
__builtin_amdgcn_sched_barrier
(
0
);
// Hot loop
while
(
i_total_loops
<
num_total_loop
)
{
auto
q_block_tile
=
load_tile
(
q_dram_window
);
clear_tile
(
st_acc
);
// Initialize S^T
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
auto
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
if
constexpr
(
k0_loops
>
1
)
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
st_acc
,
get_slice_tile
(
q_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
block_sync_lds
();
});
}
store_tile
(
q_lds_window
,
q_block_tile
);
shuffle_tile
(
shuffled_q_block_tile
,
q_block_tile
);
store_tile
(
shuffled_q_lds_write_window
,
shuffled_q_block_tile
);
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
auto
do_block_tile
=
load_tile
(
do_dram_window
);
// prefetch load OGrad tile
{
// tail
block_sync_lds
();
gemm_0
(
st_acc
,
get_slice_tile
(
q_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kN0
,
k0_loops
*
kK0
>
{}));
auto
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
auto
lse
=
load_tile
(
lse_lds_read_window
);
block_sync_lds
();
}
// STAGE 1, Q@K Gemm0
auto
s_acc
=
SPBlockTileType
{};
s_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
block_sync_lds
(
);
auto
bias_
shuffle
_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
auto
shuffle
d_bias_tile
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_
shuffle
_tmp
,
bias_tile
);
store_tile
(
bias
t
_lds_
shuffl
e_window
,
bias_
shuffle
_tmp
);
shuffle_tile
(
shuffle
d_bias_tile
,
bias_tile
);
store_tile
(
bias_lds_
writ
e_window
,
shuffle
d_bias_tile
);
block_sync_lds
();
auto
bias
t
_tile
=
load_tile
(
bias
t
_lds_window
);
auto
bias
_s
_tile
=
load_tile
(
bias
_s
_lds_
read_
window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
=
raw_scale
*
x
+
type_convert
<
AccDataType
>
(
y
);
#else
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
#endif
},
s
t
_acc
,
bias
t
_tile
);
s_acc
,
bias
_s
_tile
);
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
constexpr
auto
st_spans
=
decltype
(
st_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
st_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
st_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s
t
_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc
(
i_j_idx
)
*=
raw_scale
;
#else
st_acc
(
i_j_idx
)
*=
scale
;
#endif
position_encoding
.
update
(
st_acc
(
i_j_idx
),
row
,
col
);
s_acc
(
i_j_idx
)
*=
scale
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
st_acc
);
#endif
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s
t
_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
set_tile_if
(
s_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
lse
=
load_tile
(
lse_dram_window
);
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
...
...
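The pipeline recomputes P from S and the stored log-sum-exp rather than saving P in the forward pass; with CK_TILE_FMHA_FWD_FAST_EXP2 the exponential is taken in base 2 so the hardware `exp2` path can be used. A scalar view of that identity (a sketch, assuming `lse` is stored in natural log and the host pre-multiplies the softmax scale by log2(e)):

    // p = exp(scale_raw * s - lse) == exp2(scale * s - log2(e) * lse), scale = scale_raw * log2(e).
    float recompute_p(float s_acc, float lse, float scale, float log2e)
    {
        return exp2f(scale * s_acc - log2e * lse);
    }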
@@ -499,157 +585,162 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
                 }
             };
-            auto pt = SPTBlockTileType{};
-            constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
-            sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
+            auto p = SPBlockTileType{};
+            constexpr auto p_spans = decltype(p)::get_distributed_spans();
+            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                 constexpr auto i_idx = make_tuple(idx0);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
 #endif
-                sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
+                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                     constexpr auto i_j_idx = make_tuple(idx0, idx1);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                     if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                  BiasEnum == BlockAttentionBiasEnum::ALIBI)
                     {
-                        pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
+                        p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
                     }
                     else
                     {
-                        pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
+                        p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
                     }
 #else
                     pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
 #endif
                 });
             });
-            if constexpr(kHasDropout)
+            if constexpr(FmhaDropout::IsDropout)
             {
-                dropout.Run<decltype(gemm_0), RandValOutputDataType>(
-                    seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
+                dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
+                    seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
             }
-            // STAGE 3, P^T@OGrad^T Gemm1
-            block_sync_lds();
-            store_tile(do_lds_window, do_block_tile); // store the prefetch
-            const auto pt_gemm = [&]() {
-                if constexpr(kHasDropout)
+            const auto p_gemm = [&]() {
+                if constexpr(FmhaDropout::IsDropout)
                 {
                     return tile_elementwise_in(
                         [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
-                        pt);
+                        p);
                 }
                 else
                 {
-                    return cast_tile<GemmDataType>(pt);
+                    return cast_tile<GemmDataType>(p);
                }
            }();
-            static_for<0, k1_loops, 1>{}([&](auto i_k1) {
-                block_sync_lds();
-                gemm_1(dv_acc,
-                       get_slice_tile(pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
-                       get_slice_tile(dot_lds_window, sequence<0, i_k1 * kK1>{}, sequence<kVHeaddim, (i_k1 + 1) * kK1>{}));
+            // STAGE 3, P^T@OGrad^T Gemm1
+            auto do_block_tile = load_tile(do_dram_window);
+            move_tile_window(do_dram_window, {kM0, 0});
+            auto d_block_tile = load_tile(d_dram_window);
+            move_tile_window(d_dram_window, {kM0});
+            store_tile(do_lds_window, do_block_tile);
+            shuffle_tile(shuffled_do_block_tile, do_block_tile);
+            store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
+            store_tile(d_lds_write_window, d_block_tile);
             block_sync_lds();
-            });
-            // STAGE 4, OGrad@V Gemm2
-            auto dpt_acc = SPGradTBlockTileType{};
-            clear_tile(dpt_acc); // Initialize PGrad^T
+            auto dot_reg_tensor = load_tile(dot_lds_read_window);
-            static_for<0, k2_loops, 1>{}([&](auto i_k2) {
-                block_sync_lds();
-                gemm_2(dpt_acc,
-                       get_slice_tile(do_lds_window, sequence<0, i_k2 * kK2>{}, sequence<kM0, (i_k2 + 1) * kK2>{}),
-                       get_slice_tile(v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
+            Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(p_gemm)>(
+                pt_reg_tensor, p_gemm);
+            gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
+            // STAGE 4, OGrad@V Gemm2
+            auto do_reg_tensor = load_tile(do_lds_read_window);
+            auto d = load_tile(d_lds_read_window);
             block_sync_lds();
-            });
-            // STAGE 5, P^T(PGrad^T - D)
-            const auto d = load_tile(d_dram_window);
+            auto dp_acc = SPGradBlockTileType{};
+            dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
-            auto dst = SPGradTBlockTileType{};
-            constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
-            sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
+            // STAGE 5, P^T(PGrad^T - D)
+            auto ds = SPGradBlockTileType{};
+            constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
+            sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
                 constexpr auto i_idx = make_tuple(idx0);
-                sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
+                sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
                     constexpr auto i_j_idx = make_tuple(idx0, idx1);
-                    bool undrop_flag = pt[i_j_idx] >= 0;
-                    dst(i_j_idx) = pt[i_j_idx] * (!kHasDropout || undrop_flag
-                                                      ? (dpt_acc[i_j_idx] - d[i_idx])
-                                                      : d[i_idx]);
+                    bool undrop_flag = p[i_j_idx] >= 0;
+                    ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
                                                    ? (dp_acc[i_j_idx] - d[i_idx])
                                                    : d[i_idx]);
                 });
             });
             if constexpr(kHasBiasGrad)
             {
-                const auto dbiast = [&]() {
-                    if constexpr(kHasDropout)
+                const auto dbias = [&]() {
+                    if constexpr(FmhaDropout::IsDropout)
                     {
                         return tile_elementwise_in(
                             [&rp_undrop](const auto& x) { return type_convert<BiasGradDataType>(x * rp_undrop); },
-                            dst);
+                            ds);
                     }
                    else
                    {
-                        return cast_tile<BiasGradDataType>(dst);
+                        return cast_tile<BiasGradDataType>(ds);
                    }
                }();
-                store_tile(biast_lds_shuffle_window, dbiast);
+                store_tile(bias_lds_write_window, dbias);
                 block_sync_lds();
-                auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
-                auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
+                auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
+                auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
                     Policy::template MakeBiasTileDistribution<Problem>());
-                shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
-                store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
-                move_tile_window(dbias_dram_block_window, {kM0, 0});
+                shuffle_tile(dbias_tile, shuffled_dbias_tile);
+                store_tile(dbias_dram_window, dbias_tile);
+                move_tile_window(dbias_dram_window, {kM0, 0});
+                __builtin_amdgcn_sched_barrier(0);
             }
             // STAGE 6, SGrad^T@Q^T Gemm3
+            auto qt_reg_tensor = load_tile(qt_lds_read_window);
             block_sync_lds();
-            const auto dst_gemm = cast_tile<GemmDataType>(dst);
-            static_for<0, k3_loops, 1>{}([&](auto i_k3) {
-                block_sync_lds();
-                gemm_3(dk_acc,
-                       get_slice_tile(dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
-                       get_slice_tile(qt_lds_window, sequence<0, i_k3 * kK3>{}, sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
-                block_sync_lds();
-            });
+            const auto ds_gemm = cast_tile<GemmDataType>(ds);
-            // STAGE 7, SGrad@K^T Gemm4
-            store_tile(ds_lds_window, dst_gemm);
+            Policy::template SGradTFromGemm2CToGemm3A<Problem, decltype(dst_reg_tensor), decltype(ds_gemm)>(
+                dst_reg_tensor, ds_gemm);
-            auto dq_acc = QGradBlockTileType{};
-            clear_tile(dq_acc); // Initialize QGrad
+            gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
+            store_tile(ds_lds_window, ds_gemm);
             block_sync_lds();
+            auto ds_reg_tensor = load_tile(ds_lds_read_window);
+            auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
+            move_tile_window(ds_lds_read_window, {0, kK4});
+            // STAGE7 SGrad@K^T Gemm4
+            auto dq_acc = QGradBlockTileType{};
+            clear_tile(dq_acc);
             static_for<0, k4_loops, 1>{}([&](auto i_k4) {
-                gemm_4(dq_acc,
-                       get_slice_tile(ds_lds_window, sequence<0, i_k4 * kK4>{}, sequence<kM0, (i_k4 + 1) * kK4>{}),
-                       get_slice_tile(kt_lds_window,
-                                      sequence<0, i_k4 * kK4>{},
-                                      sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
-            });
+                if constexpr(i_k4 < k4_loops - 1)
+                {
+                    ds_reg_tensor_next = load_tile(ds_lds_read_window);
+                    move_tile_window(ds_lds_read_window, {0, kK4});
+                }
+                auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
+                                                          sequence<0, i_k4 * kK4>{},
+                                                          sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
+                gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
+                if constexpr(i_k4 < k4_loops - 1)
+                {
+                    ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
+                }
+            });
+            move_tile_window(ds_lds_read_window, {0, -kN0});
             // QGrad Scale
-            if constexpr(kHasDropout)
+            if constexpr(FmhaDropout::IsDropout)
             {
                 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dq_acc);
...
...
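The gemm_4 loop above double-buffers the dS slices read back from LDS: while slice i feeds the MFMA, slice i+1 is already in flight. The pattern in isolation (a sketch; `slice_of_kt(i)` is a hypothetical stand-in for the `get_slice_tile` on `kt_reg_tensor`):

    // Ping-pong LDS read overlapped with compute.
    auto cur = load_tile(ds_lds_read_window);                 // slice 0
    move_tile_window(ds_lds_read_window, {0, kK4});
    static_for<0, k4_loops, 1>{}([&](auto i) {
        decltype(cur) nxt{};
        if constexpr(i < k4_loops - 1)
        {
            nxt = load_tile(ds_lds_read_window);              // prefetch slice i+1
            move_tile_window(ds_lds_read_window, {0, kK4});
        }
        gemm_4(dq_acc, cur, slice_of_kt(i));                  // compute on slice i
        if constexpr(i < k4_loops - 1)
            cur.get_thread_buffer() = nxt.get_thread_buffer(); // rotate buffers
    });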
@@ -658,34 +749,33 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
}
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
update_tile
(
dq_dram_block_window
,
dq
);
if
constexpr
(
kIsDeterministic
)
{
store_tile
(
dq_dram_window
,
dq_acc
);
}
else
{
update_tile
(
dq_dram_window
,
dq_acc
);
}
move_tile_window
(
dq_dram_window
,
{
kM0
,
0
});
// move tile windows
move_tile_window
(
q_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
dq_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
do_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
lse_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
}
while
(
++
i_total_loops
<
num_total_loop
);
i_total_loops
+=
1
;
seqlen_q_step
+=
kM0
;
}
//
KGrad
Scale
if
constexpr
(
kHa
sDropout
)
//
Results
Scale
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dk_acc
);
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
}
// VGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
};
...
...
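The deterministic branch above plausibly pairs with the new `BlockFmhaBwdConvertQGrad`: each N0 tile writes its own dQ partial with `store_tile` (no cross-workgroup atomics, hence a reproducible result), and a follow-up kernel reduces the partials; the non-deterministic path accumulates in place with `update_tile`. A hedged host-side sketch of how the two launches could fit together (all names illustrative, not from this commit):

    // Hypothetical dispatch flow; launch helpers are stand-ins.
    if(is_deterministic) {
        launch_bwd_dq_dk_dv(args);               // each split stores its own dq_acc buffer
        launch_convert_dq(args, /*nsplits=*/n);  // BlockFmhaBwdConvertQGrad: reduce + convert
    } else {
        launch_bwd_dq_dk_dv(args);               // splits update_tile into one dq buffer
        launch_convert_dq(args);                 // convert-only overload
    }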
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp → include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp  View file @ f84e2020
...
...
@@ -6,13 +6,13 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"

 namespace ck_tile {

-template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
-struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
+template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
+struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
 {
     using QDataType = remove_cvref_t<typename Problem::QDataType>;
     using KDataType = remove_cvref_t<typename Problem::KDataType>;
...
...
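The `_iglp` suffix and the new `HotLoopScheduler` policy hook indicate an instruction-scheduling variant of the same kr_ktr_vr dataflow. What is certain from this diff is the use of `__builtin_amdgcn_sched_barrier`, the AMDGPU compiler-scheduling fence; with a zero mask it forbids the backend from moving any instruction across the barrier, which these pipelines use to pin the issue order of global loads:

    // Appears verbatim in the hot loops above; mask 0 = no instruction may cross.
    __builtin_amdgcn_sched_barrier(0);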
@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
     using VGradDataType    = remove_cvref_t<typename Problem::VGradDataType>;
     using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
     using FmhaMask         = remove_cvref_t<typename Problem::FmhaMask>;
+    using FmhaDropout      = remove_cvref_t<typename Problem::FmhaDropout>;
+    using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;

     using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
...
...
@@ -46,14 +48,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
     static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
     static constexpr index_t kVHeaddim  = BlockFmhaShape::kVHeaddim;

-    static constexpr bool kQLoadOnce      = false;
-    static constexpr bool kQTLoadOnce     = false;
-    static constexpr bool kKLoadOnce      = true;
-    static constexpr bool kKTLoadOnce     = true;
-    static constexpr bool kVLoadOnce      = true;
-    static constexpr bool kOGradLoadOnce  = false;
-    static constexpr bool kOGradTLoadOnce = false;
-
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
     static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
...
...
@@ -61,7 +55,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
-    static constexpr bool kHasDropout  = Problem::kHasDropout;
+    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

     // last dimension vector length used to create tensor view(and decide buffer_load vector length)
     // ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
kPadHeadDimQ
?
2
:
Policy
::
template
GetAlignmentQGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
1
;
static
constexpr
index_t
kAlignmentKGrad
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentKGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentVGrad
=
...
...
@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
     static constexpr index_t kAlignmentBias =
         kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();

-    static constexpr const char* name = "ks_kts_vr";
+    static constexpr const char* name = "kr_ktr_vr_iglp";

     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
...
...
@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
     }

     template <typename QDramBlockWindowTmp,
-              typename QTDramBlockWindowTmp,
               typename KDramBlockWindowTmp,
-              typename KTDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
               typename RandValDramBlockWindowTmp,
               typename OGradDramBlockWindowTmp,
-              typename OGradTDramBlockWindowTmp,
               typename LSEDramBlockWindowTmp,
               typename DDramBlockWindowTmp,
               typename QGradDramBlockWindowTmp,
...
...
@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
-               const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
                const KDramBlockWindowTmp& k_dram_block_window_tmp,
-               const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
                const VDramBlockWindowTmp& v_dram_block_window_tmp,
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
                const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
-               const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
                const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
                const DDramBlockWindowTmp& d_dram_block_window_tmp,
                const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
...
...
@@ -122,43 +107,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
,
#endif
float
rp_undrop
,
float
scale_rp_undrop
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
LSEDataType
,
remove_cvref_t
<
typename
LSEDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QGradDataType
,
remove_cvref_t
<
typename
QGradDramBlockWindowTmp
::
DataType
>>
,
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kQKHeaddim
==
QTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kQKHeaddim
==
KTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
OGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kVHeaddim
==
OGradTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
LSEDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
DDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
QGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
...
...
@@ -166,83 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                          kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");
        // Q tile in LDS
        QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        auto q_lds_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});

        // QT tile in LDS
        QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto qt_lds = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
        auto qt_lds_window =
            make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});

        // K tile in LDS
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<KDataType*>(smem_ptr),
            Policy::template MakeKLdsBlockDescriptor<Problem>());
        auto k_lds_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});

        // KT tile in LDS
        KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto kt_lds = make_tensor_view<address_space_enum::lds>(
            kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
        auto kt_lds_window =
            make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});

        // OGrad tile in LDS
        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto do_lds = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
        auto do_lds_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});

        // OGradT tile in LDS
        OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto dot_lds = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
        auto dot_lds_window =
            make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});

        // SGrad tile in LDS
        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto ds_lds = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        // BiasT/BiasGradT tile in LDS, use the same size and layout
        BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto biast_lds = make_tensor_view<address_space_enum::lds>(
            biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
        auto biast_lds_shuffle_window =
            make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto dbiast_lds_shuffle_window =
            make_tile_window(biast_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());
        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");

        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...
...
@@ -250,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();

-       auto v_dram_window =
-           make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
-                            v_dram_block_window_tmp.get_window_lengths(),
-                            v_dram_block_window_tmp.get_window_origin(),
-                            Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
-       auto v = load_tile(v_dram_window); // persistent V register tile

        using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
        using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
        using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());

        // init VGrad & KGrad
        auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
        auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
        clear_tile(dv_acc);
        clear_tile(dk_acc);

        // K, HBM ->LDS ->Reg
        auto k_dram_window =
            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                             k_dram_block_window_tmp.get_window_lengths(),
                             k_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load
        __builtin_amdgcn_sched_barrier(0);

        const auto k_origin = k_dram_window.get_window_origin();
        // Early termination
        const auto [seqlen_q_start, seqlen_q_end] =
            mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
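        // Early termination: with a causal or windowed mask only a sub-range of
        // the query sequence can attend to this K/V tile, and GetTileRangeAlongY
        // returns that range. When the range is empty, the block returns the
        // zero-initialized dk_acc/dv_acc immediately (the branch below), skipping
        // the whole main loop.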
...
...
@@ -290,272 +169,444 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
        {
            // Note: here dk_acc & dv_acc are all cleared, return them
            // Note: v loaded but no fence, ignore it.
-           return ck_tile::make_tuple(dk_acc, dv_acc);
+           return make_tuple(dk_acc, dv_acc);
        }
        }

        KDataType* k_lds_ptr =
            static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
-       auto k_block_tile = load_tile(k_dram_window);
        auto k_lds_write_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
        auto k_lds_read_window =
            make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
                             make_tuple(number<kN0>{}, number<kK0>{}),
                             k_lds_write_window.get_window_origin(),
                             Policy::template MakeKRegSliceBlockDescriptor<Problem>());
        auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
            Policy::template MakeKRegBlockDescriptor<Problem>());
-       store_tile(k_lds_window, k_block_tile); // persistent K in LDS
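        // K is staged HBM -> LDS -> registers once per workgroup: the write and
        // read windows alias the same LDS storage but carry different tile
        // distributions, so the register read directly yields the layout the
        // block GEMMs consume.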
        //------------------------------------------------------------------
        // V, HBM ->LDS ->Reg
        auto v_dram_window =
            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
                             v_dram_block_window_tmp.get_window_lengths(),
                             v_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeVDramTileDistribution<Problem>());
        VDataType* v_lds_ptr =
            static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto v_lds = make_tensor_view<address_space_enum::lds>(
            v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
        auto v_lds_write_window =
            make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
        auto v_lds_read_window =
            make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
                             make_tuple(number<kN0>{}, number<kK2>{}),
                             v_lds_write_window.get_window_origin(),
                             Policy::template MakeVRegSliceBlockDescriptor<Problem>());
        auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
            Policy::template MakeVRegBlockDescriptor<Problem>());

        //------------------------------------------------------------------
        // KT, Reg ->LDS ->Reg
        auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
            Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
        KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
            kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
        auto kt_dram_block_window = kt_dram_block_window_tmp;
        auto shuffled_k_lds_write_window = make_tile_window(
            shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
        auto kt_dram_window =
            make_tile_window(kt_dram_block_window.get_bottom_tensor_view(),
                             kt_dram_block_window.get_window_lengths(),
                             kt_dram_block_window.get_window_origin(),
                             Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for load
        auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
            kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
        auto kt_block_tile = load_tile(kt_dram_window);
        auto kt_lds_read_window =
            make_tile_window(kt_lds_read,
                             make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeKTRegBlockDescriptor<Problem>());
        //------------------------------------------------------------------
        // Pre-Load KV into Registers
        auto k_block_tile = load_tile(k_dram_window);
        auto v_block_tile = load_tile(v_dram_window);
        store_tile(k_lds_write_window, k_block_tile);
        shuffle_tile(shuffled_k_block_tile, k_block_tile);
        store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
        block_sync_lds();
        k_reg_tensor = load_tile(k_lds_read_window);
        block_sync_lds();
        auto kt_reg_tensor  = load_tile(kt_lds_read_window);
        auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>(
            Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
        shuffle_tile(kt_shuffle_tmp, kt_block_tile);
        store_tile(v_lds_write_window, v_block_tile);
        store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS
        block_sync_lds();
        v_reg_tensor = load_tile(v_lds_read_window);
        //---------------------------- Loop Load in ----------------------------//
        // Q: HBM ->Reg ->LDS
-       auto q_dram_block_window =
-           make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
-                            q_dram_block_window_tmp.get_window_lengths(),
-                            {seqlen_q_start, 0});
+       auto q_dram_window =
+           make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
+                            q_dram_block_window_tmp.get_window_lengths(),
+                            {seqlen_q_start, 0},
+                            Policy::template MakeQDramTileDistribution<Problem>());
        QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>()));
-       auto qt_dram_block_window =
-           make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
-                            qt_dram_block_window_tmp.get_window_lengths(),
-                            {0, seqlen_q_start});
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        auto q_lds_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
        auto q_lds_read_window =
            make_tile_window(q_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK0>{}),
                             q_lds_window.get_window_origin(),
                             Policy::template MakeQRegSliceBlockDescriptor<Problem>());
        auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
            Policy::template MakePTRegSliceBlockDescriptor<Problem>());

        // QT: Reg -> Reg -> LDS
        auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
            Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
        QDataType* qt_lds_ptr =
            static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
        auto shuffled_q_lds_write_window = make_tile_window(
            shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
        auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
        auto qt_lds_read_window =
            make_tile_window(qt_lds_read,
                             make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
                             {0, 0},
                             Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
        // dO: HBM ->Reg ->LDS
-       auto do_dram_block_window =
-           make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
-                            do_dram_block_window_tmp.get_window_lengths(),
-                            {seqlen_q_start, 0});
+       auto do_dram_window =
+           make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
+                            do_dram_block_window_tmp.get_window_lengths(),
+                            {seqlen_q_start, 0},
+                            Policy::template MakeOGradDramTileDistribution<Problem>());
-       auto dot_dram_block_window =
-           make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
-                            dot_dram_block_window_tmp.get_window_lengths(),
-                            {0, seqlen_q_start});
        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
-       auto dq_dram_block_window =
-           make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
-                            dq_dram_block_window_tmp.get_window_lengths(),
-                            {seqlen_q_start, 0});
        auto do_lds = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
-       auto lse_dram_block_window =
-           make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
-                            lse_dram_block_window_tmp.get_window_lengths(),
-                            {seqlen_q_start});
        auto do_lds_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
-       auto d_dram_block_window =
-           make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
-                            d_dram_block_window_tmp.get_window_lengths(),
-                            {seqlen_q_start});
        auto do_lds_read_window =
            make_tile_window(do_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK2>{}),
                             do_lds_window.get_window_origin(),
                             Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());

        // dOT: Reg ->Reg ->LDS
        auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
            Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
        OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>()));
        auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
        auto shuffled_do_lds_write_window = make_tile_window(
            shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
        auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
        auto dot_lds_read_window =
            make_tile_window(dot_read_lds,
                             make_tuple(number<kVHeaddim>{}, number<kM0>{}),
                             {0, 0},
                             Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
        // dS: Reg -> Reg -> LDS
        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>() +
            Policy::template GetSmemSizeQ<Problem>() +
            Policy::template GetSmemSizeLSE<Problem>() +
            Policy::template GetSmemSizeD<Problem>()));
        auto ds_lds = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto ds_lds_read_window =
            make_tile_window(ds_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK4>{}),
                             ds_lds_window.get_window_origin(),
                             Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
        auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
            Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
        // Bias: HBM ->Reg ->Reg ->LDS
        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-       auto bias_dram_block_window =
-           make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
-                            bias_dram_block_window_tmp.get_window_lengths(),
-                            {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
+       auto bias_dram_window =
+           make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+                            bias_dram_block_window_tmp.get_window_lengths(),
+                            {seqlen_q_start, bias_origin.at(number<1>{})},
+                            Policy::template MakeBiasTileDistribution<Problem>());
-       const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
-       auto dbias_dram_block_window =
-           make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
-                            dbias_dram_block_window_tmp.get_window_lengths(),
-                            {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
        BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>() +
            Policy::template GetSmemSizeQ<Problem>() +
            Policy::template GetSmemSizeLSE<Problem>() +
            Policy::template GetSmemSizeD<Problem>()));
        auto bias_lds = make_tensor_view<address_space_enum::lds>(
            bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
        auto bias_lds_write_window =
            make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
-       auto qt_dram_window =
-           make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
-                            qt_dram_block_window.get_window_lengths(),
-                            qt_dram_block_window.get_window_origin(),
-                            Policy::template MakeQTDramTileDistribution<Problem>());
+       auto bias_s_lds_read_window =
+           make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
+                            bias_lds_write_window.get_window_lengths(),
+                            bias_lds_write_window.get_window_origin(),
+                            Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
        auto dot_dram_window =
            make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
                             dot_dram_block_window.get_window_lengths(),
                             dot_dram_block_window.get_window_origin(),
                             Policy::template MakeOGradTDramTileDistribution<Problem>());
        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");

        // LSE: HBM -> LDS ->Reg
-       auto lse_dram_window = make_tile_window(
-           lse_dram_block_window.get_bottom_tensor_view(),
-           lse_dram_block_window.get_window_lengths(),
-           lse_dram_block_window.get_window_origin(),
-           Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
+       auto lse_dram_window = make_tile_window(
+           lse_dram_block_window_tmp.get_bottom_tensor_view(),
+           lse_dram_block_window_tmp.get_window_lengths(),
+           {seqlen_q_start},
+           Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>() +
            Policy::template GetSmemSizeQ<Problem>()));
        auto lse_lds = make_tensor_view<address_space_enum::lds>(
            lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
        auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
        auto lse_lds_read_window =
            make_tile_window(lse_lds,
                             make_tuple(number<kM0>{}),
                             {0},
                             Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());

        // D: HBM ->Reg
-       auto d_dram_window = make_tile_window(
-           d_dram_block_window.get_bottom_tensor_view(),
-           d_dram_block_window.get_window_lengths(),
-           d_dram_block_window.get_window_origin(),
-           Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
+       auto d_dram_window = make_tile_window(
+           d_dram_block_window_tmp.get_bottom_tensor_view(),
+           d_dram_block_window_tmp.get_window_lengths(),
+           {seqlen_q_start},
+           Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
-       auto bias_dram_window =
-           make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
-                            bias_dram_block_window.get_window_lengths(),
-                            bias_dram_block_window.get_window_origin(),
-                            Policy::template MakeBiasTileDistribution<Problem>());
        DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>() +
            Policy::template GetSmemSizeQ<Problem>() +
            Policy::template GetSmemSizeLSE<Problem>()));
        auto d_lds = make_tensor_view<address_space_enum::lds>(
            d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
        auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
-       auto biast_lds_window =
-           make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
-                            biast_lds_shuffle_window.get_window_lengths(),
-                            biast_lds_shuffle_window.get_window_origin(),
-                            Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
+       auto d_lds_read_window =
+           make_tile_window(d_lds,
+                            make_tuple(number<kM0>{}),
+                            {0},
+                            Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());

        // RandVal: HBM ->Reg
-       auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
+       auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
            randval_dram_block_window_tmp, seqlen_q_start);

        // BiasGrad
        // Reg ->LDS ->Reg ->HBM
        const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
        auto dbias_dram_window =
            make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                             dbias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
        auto dbias_lds_read_window =
            make_tile_window(bias_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());
        // ----------------------------Loop write out------------------------------//
        auto dq_dram_window =
            make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});

        using SPBlockTileType     = decltype(gemm_0.MakeCBlockTile());
        using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
        using QGradBlockTileType  = decltype(gemm_4.MakeCBlockTile());

        index_t i_total_loops      = 0;
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kM0 / kK1;
        constexpr index_t k2_loops = kVHeaddim / kK2;
        constexpr index_t k3_loops = kM0 / kK3;
        index_t seqlen_q_step      = seqlen_q_start;
        static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
        static_assert(kM0 == kK1, "kM0 should equal to kK1");
        static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
        static_assert(kM0 == kK3, "kM0 should equal to kK3");
        constexpr index_t k4_loops = kN0 / kK4;
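        // These asserts pin each inner GEMM to a single K-slice per tile
        // (k0_loops == k1_loops == k2_loops == k3_loops == 1 for this variant);
        // only the SGrad@K^T GEMM still iterates, k4_loops = kN0 / kK4 times.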
-       do
-       {
-           auto q_dram_window =
-               make_tile_window(q_dram_block_window.get_bottom_tensor_view(),
-                                q_dram_block_window.get_window_lengths(),
-                                q_dram_block_window.get_window_origin(),
-                                Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for load
-           auto do_dram_window =
-               make_tile_window(do_dram_block_window.get_bottom_tensor_view(),
-                                do_dram_block_window.get_window_lengths(),
-                                do_dram_block_window.get_window_origin(),
-                                Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile window for load
-           // STAGE 1, Q@K Gemm0
-           auto st_acc = SPTBlockTileType{};
        /*
         * Prefetch Q, LSE, dO, D
         */
        auto q_block_tile = load_tile(q_dram_window);
        {
-           move_tile_window(q_dram_window, {0, kK0});
-           clear_tile(st_acc); // Initialize S^T
+           move_tile_window(q_dram_window, {kM0, 0});
            auto lse_block_tile = load_tile(lse_dram_window);
            move_tile_window(lse_dram_window, {kM0});
            store_tile(q_lds_window, q_block_tile);     // LDS write 0
            q_block_tile = load_tile(q_dram_window);    // global read 1
        }
        auto do_block_tile = load_tile(do_dram_window);
        move_tile_window(do_dram_window, {kM0, 0});
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
        }
        const auto bias_tile = load_tile(bias_dram_window); // load bias tile
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
        }
        auto d_block_tile = load_tile(d_dram_window);
        move_tile_window(d_dram_window, {kM0});
-       if constexpr(k0_loops > 2)
-       {
-           static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
-               block_sync_lds();
-               gemm_0(st_acc,
-                      q_lds_window,
-                      get_slice_tile(k_lds_window,
-                                     sequence<0, i_k0 * kK0>{},
-                                     sequence<kN0, (i_k0 + 1) * kK0>{}));
                /*
                 * Store prefetched data into LDS
                 */
-               block_sync_lds();
-               move_tile_window(q_dram_window, {0, kK0});
+               store_tile(q_lds_window, q_block_tile);
+               shuffle_tile(shuffled_q_block_tile, q_block_tile);
+               store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
-               store_tile(q_lds_window, q_block_tile);  // LDS write i + 1
-               q_block_tile = load_tile(q_dram_window); // global read i + 2
-           });
-       }
        store_tile(lse_lds_write_window, lse_block_tile);
-       const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
-       { // tail
-           block_sync_lds();
-           gemm_0(st_acc,
-                  q_lds_window,
-                  get_slice_tile(k_lds_window,
-                                 sequence<0, (k0_loops - 2) * kK0>{},
-                                 sequence<kN0, (k0_loops - 1) * kK0>{}));
            block_sync_lds();
            store_tile(do_lds_window, do_block_tile);
            shuffle_tile(shuffled_do_block_tile, do_block_tile);
            store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
-           store_tile(q_lds_window, q_block_tile);
            store_tile(d_lds_write_window, d_block_tile);
            block_sync_lds();
-           gemm_0(st_acc,
-                  q_lds_window,
-                  get_slice_tile(k_lds_window,
-                                 sequence<0, (k0_loops - 1) * kK0>{},
-                                 sequence<kN0, k0_loops * kK0>{}));
-       }
        /*
         * Prefetch LDS data into Reg to Asynchronous Data Movement and MFMA pipeline
         */
        auto q_reg_tensor  = load_tile(q_lds_read_window);
        auto lse           = load_tile(lse_lds_read_window);
        auto do_reg_tensor = load_tile(do_lds_read_window);
        auto d             = load_tile(d_lds_read_window);
        clear_tile(dv_acc);
        clear_tile(dk_acc);
        __builtin_amdgcn_sched_barrier(0);
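        // __builtin_amdgcn_sched_barrier(0) keeps the compiler from moving any
        // instruction across this point, so the register prefetches above stay
        // ahead of the hot loop's MFMA work instead of being rescheduled into it.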
        // Hot loop
        while(i_total_loops < (num_total_loop - 1))
        {
            // STAGE 1, Q@K Gemm0
            auto s_acc   = SPBlockTileType{};
            q_block_tile = load_tile(q_dram_window);
            move_tile_window(q_dram_window, {kM0, 0});
            lse_block_tile = load_tile(lse_dram_window);
            move_tile_window(lse_dram_window, {kM0});
            do_block_tile = load_tile(do_dram_window);
            move_tile_window(do_dram_window, {kM0, 0});
            d_block_tile = load_tile(d_dram_window);
            move_tile_window(d_dram_window, {kM0});
            s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
            auto dot_reg_tensor = load_tile(dot_lds_read_window);
            HotLoopScheduler::template GemmStagedScheduler<0>();
            __builtin_amdgcn_sched_barrier(0);
            // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
-               block_sync_lds();
-               auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
+               const auto bias_tile    = load_tile(bias_dram_window);
+               auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
                    Policy::template MakeShuffledBiasTileDistribution<Problem>());
-               shuffle_tile(bias_shuffle_tmp, bias_tile);
-               store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
+               shuffle_tile(shuffled_bias_tile, bias_tile);
+               store_tile(bias_lds_write_window, shuffled_bias_tile);
                block_sync_lds();
-               auto biast_tile  = load_tile(biast_lds_window);
+               auto bias_s_tile = load_tile(bias_s_lds_read_window);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x = raw_scale * x + type_convert<AccDataType>(y);
#else
                        x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
                    },
-                   st_acc,
-                   biast_tile);
+                   s_acc,
+                   bias_s_tile);
                move_tile_window(bias_dram_window, {kM0, 0});
                __builtin_amdgcn_sched_barrier(0);
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
-               const auto q_origin     = q_dram_block_window.get_window_origin();
-               constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
-               sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
-                   sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
+               constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+               sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                   sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
-                           st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                           s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
-                       const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                       const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
-#if !CK_TILE_FMHA_FWD_FAST_EXP2
-                       st_acc(i_j_idx) *= raw_scale;
-#else
-                       st_acc(i_j_idx) *= scale;
-#endif
-                       position_encoding.update(st_acc(i_j_idx), row, col);
+                       s_acc(i_j_idx) *= scale;
+                       position_encoding.update(s_acc(i_j_idx), row, col);
                    });
                });
            }
            else
            {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
            }
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
-               const auto q_origin      = q_dram_block_window.get_window_origin();
-               bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
-                                                          k_origin.at(number<0>{}),
-                                                          number<kM0>{},
-                                                          number<kN0>{});
+               bool need_perpixel_check = mask.IsEdgeTile(
+                   seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
                if(need_perpixel_check)
                {
-                   set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
-                       const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                   set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
+                       const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        return mask.IsOutOfBound(row, col);
                    });
                }
            }
-           const auto lse = load_tile(lse_dram_window);
            static const auto get_validated_lse = [](LSEDataType raw_lse) {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
...
...
@@ -570,278 +621,416 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                }
            };
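            // P is recomputed from the forward-pass statistics rather than stored:
            // softmax(S)_ij = exp(S_ij - LSE_i), evaluated below via the hardware
            // exp2 with log2(e) folded into the scale and into row_lse.
            // get_validated_lse() maps an LSE of -inf (a fully masked row) to 0,
            // so S - LSE stays -inf instead of producing NaN from (-inf) - (-inf),
            // and exp2 then yields the desired 0.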
-           auto pt = SPTBlockTileType{};
-           constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
-           sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
+           auto p = SPBlockTileType{};
+           constexpr auto p_spans = decltype(p)::get_distributed_spans();
+           sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
-               sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
+               sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
-                       pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
+                       p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
                    }
                    else
                    {
-                       pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
+                       p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
                    }
#else
                    pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
                });
            });
-           auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
-               Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
-           block_sync_lds();
-           {
-               shuffle_tile(dot_shuffle_tmp, dot_prefetch);
-               store_tile(dot_lds_window, dot_shuffle_tmp); // store the prefetch
-           }
-           move_tile_window(dot_dram_window, {0, kK1});
-           if constexpr(kHasDropout)
+           if constexpr(FmhaDropout::IsDropout)
            {
-               dropout.Run<decltype(gemm_0), RandValOutputDataType>(
-                   seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
+               dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
+                   seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
            }
-           // STAGE 3, P^T@OGrad^T Gemm1
-           const auto pt_gemm = [&]() {
-               if constexpr(kHasDropout)
+           const auto p_gemm = [&]() {
+               if constexpr(FmhaDropout::IsDropout)
                {
                    return tile_elementwise_in(
                        [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
-                       pt);
+                       p);
                }
                else
                {
-                   return cast_tile<GemmDataType>(pt);
+                   return cast_tile<GemmDataType>(p);
                }
            }();
-           if constexpr(k1_loops > 1)
-           {
-               static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
-                   const auto dot = load_tile(dot_dram_window); // load next OGrad^T
-                   block_sync_lds();
-                   gemm_1(dv_acc,
-                          get_slice_tile(pt_gemm,
-                                         sequence<i_k1 * kK1, 0>{},
-                                         sequence<(i_k1 + 1) * kK1, kN0>{}),
-                          dot_lds_window);
+           // STAGE 3, P^T@OGrad^T Gemm1
+           Policy::template PTFromGemm0CToGemm1A<Problem,
+                                                 decltype(pt_reg_tensor),
+                                                 decltype(p_gemm)>(pt_reg_tensor, p_gemm);
+           gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
+           auto qt_reg_tensor = load_tile(qt_lds_read_window);
+           HotLoopScheduler::template GemmStagedScheduler<1>();
+           __builtin_amdgcn_sched_barrier(0);
+           // STAGE 4, OGrad@V Gemm2
+           auto dp_acc = SPGradBlockTileType{};
+           dp_acc      = gemm_2(do_reg_tensor, v_reg_tensor);
-                   block_sync_lds();
-                   shuffle_tile(dot_shuffle_tmp, dot);
-                   store_tile(dot_lds_window, dot_shuffle_tmp); // store the prefetch
-                   move_tile_window(dot_dram_window, {0, kK1});
            store_tile(q_lds_window, q_block_tile);
            shuffle_tile(shuffled_q_block_tile, q_block_tile);
            store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
            store_tile(lse_lds_write_window, lse_block_tile);
            store_tile(do_lds_window, do_block_tile);
            shuffle_tile(shuffled_do_block_tile, do_block_tile);
            store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
            store_tile(d_lds_write_window, d_block_tile);
            HotLoopScheduler::template GemmStagedScheduler<2>();
            __builtin_amdgcn_sched_barrier(0);
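            // STAGE 5 below applies the softmax backward identity
            // dS_ij = P_ij * (dP_ij - D_i), where D_i = sum_j(dO_ij * O_ij) was
            // precomputed by the dot(dO, O) pre-kernel. Under dropout, dropped
            // elements are flagged by negative P values and take the d-only branch
            // via undrop_flag.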
            // STAGE 5, P^T(PGrad^T - D)
            auto ds = SPGradBlockTileType{};
            constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
            sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    bool undrop_flag       = p[i_j_idx] >= 0;
                    ds(i_j_idx) =
                        p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
                                          ? (dp_acc[i_j_idx] - d[i_idx])
                                          : d[i_idx]);
                });
            });
            if constexpr(kHasBiasGrad)
            {
                const auto dbias = [&]() {
                    if constexpr(FmhaDropout::IsDropout)
                    {
                        return tile_elementwise_in(
                            [&rp_undrop](const auto& x) {
                                return type_convert<BiasGradDataType>(x * rp_undrop);
                            },
                            ds);
                    }
-                   auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
-                   // tail
                    else
                    {
                        return cast_tile<BiasGradDataType>(ds);
                    }
                }();
                store_tile(bias_lds_write_window, dbias);
                block_sync_lds();
-               gemm_1(dv_acc,
-                      get_slice_tile(pt_gemm,
-                                     sequence<(k1_loops - 1) * kK1, 0>{},
-                                     sequence<kM0, kN0>{}),
-                      dot_lds_window);
-               block_sync_lds();
                auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
                auto dbias_tile          = make_static_distributed_tensor<BiasGradDataType>(
                    Policy::template MakeBiasTileDistribution<Problem>());
                shuffle_tile(dbias_tile, shuffled_dbias_tile);
                store_tile(dbias_dram_window, dbias_tile);
                move_tile_window(dbias_dram_window, {kM0, 0});
                __builtin_amdgcn_sched_barrier(0);
            }
-           // STAGE 4, OGrad@V Gemm2
-           auto dpt_acc = SPGradTBlockTileType{};
            // STAGE 6, SGrad^T@Q^T Gemm3
            const auto ds_gemm = cast_tile<GemmDataType>(ds);
            Policy::template SGradTFromGemm2CToGemm3A<Problem,
                                                      decltype(dst_reg_tensor),
                                                      decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
            gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
            store_tile(ds_lds_window, ds_gemm);
            block_sync_lds();
            auto ds_reg_tensor      = load_tile(ds_lds_read_window);
            auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
            move_tile_window(ds_lds_read_window, {0, kK4});
            q_reg_tensor = load_tile(q_lds_read_window);
            lse          = load_tile(lse_lds_read_window);
            HotLoopScheduler::template GemmStagedScheduler<3>();
            __builtin_amdgcn_sched_barrier(0);
            // STAGE 7, SGrad@K^T Gemm4
            auto dq_acc = QGradBlockTileType{};
            clear_tile(dq_acc);
            static_for<0, k4_loops, 1>{}([&](auto i_k4) {
                if constexpr(i_k4 < k4_loops - 1)
                {
                    ds_reg_tensor_next = load_tile(ds_lds_read_window);
                    move_tile_window(ds_lds_read_window, {0, kK4});
                }
                auto kt_reg_tensor_slice =
                    get_slice_tile(kt_reg_tensor,
                                   sequence<0, i_k4 * kK4>{},
                                   sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
                gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
                if constexpr(i_k4 < k4_loops - 1)
                {
-                   move_tile_window(do_dram_window, {0, kK2});
                    ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
                }
            });
            move_tile_window(ds_lds_read_window, {0, -kN0});
            do_reg_tensor = load_tile(do_lds_read_window);
            d             = load_tile(d_lds_read_window);
-           clear_tile(dpt_acc); // Initialize PGrad^T
            HotLoopScheduler::template GemmStagedScheduler<4>();
            // QGrad Scale
            if constexpr(FmhaDropout::IsDropout)
            {
                tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                       dq_acc);
            }
            else
            {
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
            }
            if constexpr(kIsDeterministic)
            {
                store_tile(dq_dram_window, dq_acc);
            }
            else
            {
                update_tile(dq_dram_window, dq_acc);
            }
            move_tile_window(dq_dram_window, {kM0, 0});
-           store_tile(do_lds_window, do_block_tile);  // LDS write 0
-           do_block_tile = load_tile(do_dram_window); // global read 1
            i_total_loops += 1;
            seqlen_q_step += kM0;
        }
        __builtin_amdgcn_sched_barrier(0);
        // Tail
        auto s_acc = SPBlockTileType{};
        // STAGE 1, Q@K Gemm0
        s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
-       if constexpr(k2_loops > 2)
        // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
-           static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
-               block_sync_lds();
-               gemm_2(dpt_acc,
-                      do_lds_window,
-                      get_slice_tile(
-                          v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
+           const auto bias_tile    = load_tile(bias_dram_window);
+           auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
+               Policy::template MakeShuffledBiasTileDistribution<Problem>());
+           shuffle_tile(shuffled_bias_tile, bias_tile);
+           store_tile(bias_lds_write_window, shuffled_bias_tile);
            block_sync_lds();
-           move_tile_window(do_dram_window, {0, kK2});
+           auto bias_s_tile = load_tile(bias_s_lds_read_window);
+           tile_elementwise_inout(
+               [&](auto& x, const auto& y) {
+                   x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
+               },
+               s_acc,
+               bias_s_tile);
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
            sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
-                   store_tile(do_lds_window, do_block_tile);  // LDS write i + 1
-                   do_block_tile = load_tile(do_dram_window); // global read i + 2
                    const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    s_acc(i_j_idx) *= scale;
                    position_encoding.update(s_acc(i_j_idx), row, col);
                });
            });
        }
-       const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
-       { // tail
-           block_sync_lds();
-           gemm_2(dpt_acc,
-                  do_lds_window,
-                  get_slice_tile(v,
-                                 sequence<0, (k2_loops - 2) * kK2>{},
-                                 sequence<kN0, (k2_loops - 1) * kK2>{}));
-           block_sync_lds();
        if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
        {
            bool need_perpixel_check = mask.IsEdgeTile(
                seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
            if(need_perpixel_check)
            {
                set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                    const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                    return mask.IsOutOfBound(row, col);
                });
            }
        }
-       store_tile(do_lds_window, do_block_tile);
-       block_sync_lds();
        static const auto get_validated_lse = [](LSEDataType raw_lse) {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                         FmhaMask::IsMasking)
            {
                return raw_lse == -numeric<LSEDataType>::infinity()
                           ? type_convert<LSEDataType>(0.f)
                           : raw_lse;
            }
            else
            {
                return raw_lse;
            }
        };
        auto p = SPBlockTileType{};
        constexpr auto p_spans = decltype(p)::get_distributed_spans();
        sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
-           gemm_2(dpt_acc,
-                  do_lds_window,
-                  get_slice_tile(v,
-                                 sequence<0, (k2_loops - 1) * kK2>{},
-                                 sequence<kN0, k2_loops * kK2>{}));
            sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
                }
                else
                {
                    p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
                }
            });
        });
-       // STAGE 5, P^T(PGrad^T - D)
-       const auto d = load_tile(d_dram_window);
        if constexpr(FmhaDropout::IsDropout)
        {
            dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
                seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
        }
-       auto dst = SPGradTBlockTileType{};
-       constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
-       sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
        // STAGE 3, P^T@OGrad^T Gemm1
        const auto p_gemm = [&]() {
            if constexpr(FmhaDropout::IsDropout)
            {
                return tile_elementwise_in(
                    [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
                    p);
            }
            else
            {
                return cast_tile<GemmDataType>(p);
            }
        }();
        Policy::template PTFromGemm0CToGemm1A<Problem,
                                              decltype(pt_reg_tensor),
                                              decltype(p_gemm)>(pt_reg_tensor, p_gemm);
        auto dot_reg_tensor = load_tile(dot_lds_read_window);
        gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
        HotLoopScheduler::template GemmStagedScheduler<1>();
        // STAGE 4, OGrad@V Gemm2
        auto dp_acc        = SPGradBlockTileType{};
        auto qt_reg_tensor = load_tile(qt_lds_read_window);
        dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
        HotLoopScheduler::template GemmStagedScheduler<2>();
        // STAGE 5, P^T(PGrad^T - D)
        auto ds = SPGradBlockTileType{};
        constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
        sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
-           sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
+           sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
-               bool undrop_flag = pt[i_j_idx] >= 0;
-               dst(i_j_idx) =
-                   pt[i_j_idx] * (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx])
-                                                              : d[i_idx]);
+               bool undrop_flag = p[i_j_idx] >= 0;
+               ds(i_j_idx) =
+                   p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
+                                     ? (dp_acc[i_j_idx] - d[i_idx])
+                                     : d[i_idx]);
            });
        });
        if constexpr(kHasBiasGrad)
        {
-           const auto dbiast = [&]() {
-               if constexpr(kHasDropout)
+           const auto dbias = [&]() {
+               if constexpr(FmhaDropout::IsDropout)
                {
                    return tile_elementwise_in(
                        [&rp_undrop](const auto& x) {
                            return type_convert<BiasGradDataType>(x * rp_undrop);
                        },
-                       dst);
+                       ds);
                }
                else
                {
-                   return cast_tile<BiasGradDataType>(dst);
+                   return cast_tile<BiasGradDataType>(ds);
                }
            }();
-           store_tile(biast_lds_shuffle_window, dbiast);
+           store_tile(bias_lds_write_window, dbias);
            block_sync_lds();
-           auto dbiast_tile        = load_tile(dbiast_lds_shuffle_window);
-           auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
+           auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
+           auto dbias_tile          = make_static_distributed_tensor<BiasGradDataType>(
                Policy::template MakeBiasTileDistribution<Problem>());
-           shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
-           store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
-           move_tile_window(dbias_dram_block_window, {kM0, 0});
+           shuffle_tile(dbias_tile, shuffled_dbias_tile);
+           store_tile(dbias_dram_window, dbias_tile);
        }
        // STAGE 6, SGrad^T@Q^T Gemm3
-       auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
-           Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
-       block_sync_lds();
-       {
-           shuffle_tile(qt_shuffle_tmp, qt_prefetch);
-           store_tile(qt_lds_window, qt_shuffle_tmp); // store the prefetch
-       }
-       move_tile_window(qt_dram_window, {0, kK3});
        const auto ds_gemm = cast_tile<GemmDataType>(ds);
-       const auto dst_gemm = cast_tile<GemmDataType>(dst);
        Policy::template SGradTFromGemm2CToGemm3A<Problem,
                                                  decltype(dst_reg_tensor),
                                                  decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
-       if constexpr(k3_loops > 1)
-       {
-           static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
-               const auto qt = load_tile(qt_dram_window); // load next Q^T
-               block_sync_lds();
-               gemm_3(dk_acc,
-                      get_slice_tile(dst_gemm,
-                                     sequence<i_k3 * kK3, 0>{},
-                                     sequence<(i_k3 + 1) * kK3, kN0>{}),
-                      qt_lds_window);
-               block_sync_lds();
-               shuffle_tile(qt_shuffle_tmp, qt);
-               store_tile(qt_lds_window, qt_shuffle_tmp); // store the prefetch
+       gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
+       store_tile(ds_lds_window, ds_gemm);
-               move_tile_window(qt_dram_window, {0, kK3});
-           });
-       }
-       // tail
-       {
-           block_sync_lds();
-           gemm_3(dk_acc,
-                  get_slice_tile(dst_gemm,
-                                 sequence<(k3_loops - 1) * kK3, 0>{},
-                                 sequence<kM0, kN0>{}),
-                  qt_lds_window);
-           block_sync_lds();
-       }
-       // STAGE 7, SGrad@K^T Gemm4
-       store_tile(ds_lds_window, dst_gemm);
+       auto ds_reg_tensor      = load_tile(ds_lds_read_window);
+       auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
+       move_tile_window(ds_lds_read_window, {0, kK4});
        HotLoopScheduler::template GemmStagedScheduler<3>();
        // STAGE 7, SGrad@K^T Gemm4
        auto dq_acc = QGradBlockTileType{};
-       clear_tile(dq_acc); // Initialize QGrad
-       block_sync_lds();
+       clear_tile(dq_acc);
        static_for<0, k4_loops, 1>{}([&](auto i_k4) {
-           gemm_4(dq_acc,
-                  get_slice_tile(ds_lds_window,
-                                 sequence<0, i_k4 * kK4>{},
-                                 sequence<kM0, (i_k4 + 1) * kK4>{}),
-                  get_slice_tile(kt_lds_window,
-                                 sequence<0, i_k4 * kK4>{},
-                                 sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
-       });
-       // QGrad Scale
-       if constexpr(kHasDropout)
+           if constexpr(i_k4 < k4_loops - 1)
            {
-               tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
-                                      dq_acc);
+               ds_reg_tensor_next = load_tile(ds_lds_read_window);
+               move_tile_window(ds_lds_read_window, {0, kK4});
            }
-           else
+           auto kt_reg_tensor_slice =
+               get_slice_tile(kt_reg_tensor,
+                              sequence<0, i_k4 * kK4>{},
+                              sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
+           gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
+           if constexpr(i_k4 < k4_loops - 1)
            {
-               tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
+               ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
            }
-           const auto dq = cast_tile<QGradDataType>(dq_acc);
-           update_tile(dq_dram_block_window, dq);
+       });
-       // move tile windows
-       move_tile_window(q_dram_block_window, {kM0, 0});
-       move_tile_window(dq_dram_block_window, {kM0, 0});
-       move_tile_window(do_dram_block_window, {kM0, 0});
-       move_tile_window(lse_dram_window, {kM0});
-       move_tile_window(d_dram_window, {kM0});
-       } while(++i_total_loops < num_total_loop);
        HotLoopScheduler::template GemmStagedScheduler<4>();
-       // KGrad Scale
-       if constexpr(kHasDropout)
+       // Results Scale
+       if constexpr(FmhaDropout::IsDropout)
        {
            tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                   dq_acc);
            tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                   dk_acc);
            tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
        }
        else
        {
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
        }
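        // With dropout, the forward pass scales kept elements by
        // rp_undrop = 1 / (1 - p_drop), so the gradients are rescaled to match:
        // dV by rp_undrop, dQ/dK by scale_rp_undrop (the softmax scale folded in);
        // without dropout only the raw softmax scale is applied to dQ/dK.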
-       // VGrad Scale
-       if constexpr(kHasDropout)
+       if constexpr(kIsDeterministic)
        {
-           tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
+           store_tile(dq_dram_window, dq_acc);
        }
+       else
+       {
+           update_tile(dq_dram_window, dq_acc);
+       }

-       return ck_tile::make_tuple(dk_acc, dv_acc);
+       return make_tuple(dk_acc, dv_acc);
    }
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
deleted 100644 → 0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {

// In this pipeline, v is located in regs; k & k^t are located in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ true,
                                      /* KTLoadOnce_ = */ true,
                                      /* VLoadOnce_ = */ true,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
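The deleted alias shows the pattern this refactor replaces: memory-placement decisions (which operands stay resident for the whole tile loop) are baked into the policy as compile-time booleans, and the pipeline specializes on them. Below is a minimal, self-contained sketch of that flag pattern; the ExamplePolicy/ExamplePipeline names are illustrative stand-ins, not ck_tile APIs.

    // Illustrative sketch only -- mirrors the load-once flag pattern, not ck_tile.
    #include <cstdio>

    template <bool KLoadOnce_, bool VLoadOnce_>
    struct ExamplePolicy
    {
        static constexpr bool kKLoadOnce = KLoadOnce_; // keep K resident across the loop
        static constexpr bool kVLoadOnce = VLoadOnce_; // keep V resident across the loop
    };

    template <typename Policy>
    struct ExamplePipeline
    {
        static void describe()
        {
            std::printf("K resident: %d, V resident: %d\n",
                        Policy::kKLoadOnce, Policy::kVLoadOnce);
        }
    };

    // Mirrors the deleted alias: placement is chosen once, at compile time.
    using ExampleKSVRPolicy = ExamplePolicy</* KLoadOnce_ = */ true,
                                            /* VLoadOnce_ = */ true>;

    int main() { ExamplePipeline<ExampleKSVRPolicy>::describe(); }

Because the flags are constexpr, any branch keyed on them resolves at compile time, so an unused staging path costs nothing at run time.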
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
deleted 100644 → 0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {

template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSVR
{
    using QDataType             = remove_cvref_t<typename Problem::QDataType>;
    using KDataType             = remove_cvref_t<typename Problem::KDataType>;
    using VDataType             = remove_cvref_t<typename Problem::VDataType>;
    using GemmDataType          = remove_cvref_t<typename Problem::GemmDataType>;
    using BiasDataType          = remove_cvref_t<typename Problem::BiasDataType>;
    using LSEDataType           = remove_cvref_t<typename Problem::LSEDataType>;
    using AccDataType           = remove_cvref_t<typename Problem::AccDataType>;
    using DDataType             = remove_cvref_t<typename Problem::DDataType>;
    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
    using ODataType             = remove_cvref_t<typename Problem::ODataType>;
    using OGradDataType         = remove_cvref_t<typename Problem::OGradDataType>;
    using QGradDataType         = remove_cvref_t<typename Problem::QGradDataType>;
    using KGradDataType         = remove_cvref_t<typename Problem::KGradDataType>;
    using VGradDataType         = remove_cvref_t<typename Problem::VGradDataType>;
    using BiasGradDataType      = remove_cvref_t<typename Problem::BiasGradDataType>;
    using FmhaMask              = remove_cvref_t<typename Problem::FmhaMask>;
    using BlockFmhaShape        = remove_cvref_t<typename Problem::BlockFmhaShape>;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;

    static constexpr index_t kM0        = BlockFmhaShape::kM0;
    static constexpr index_t kN0        = BlockFmhaShape::kN0;
    static constexpr index_t kK0        = BlockFmhaShape::kK0;
    static constexpr index_t kK1        = BlockFmhaShape::kK1;
    static constexpr index_t kK2        = BlockFmhaShape::kK2;
    static constexpr index_t kK3        = BlockFmhaShape::kK3;
    static constexpr index_t kK4        = BlockFmhaShape::kK4;
    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kVHeaddim  = BlockFmhaShape::kVHeaddim;
    static constexpr bool kQLoadOnce      = false;
    static constexpr bool kQTLoadOnce     = false;
    static constexpr bool kKLoadOnce      = true;
    static constexpr bool kKTLoadOnce     = false;
    static constexpr bool kVLoadOnce      = true;
    static constexpr bool kOGradLoadOnce  = false;
    static constexpr bool kOGradTLoadOnce = false;

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
    static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
    static constexpr bool kHasDropout  = Problem::kHasDropout;
    // last dimension vector length used to create tensor view (and decide buffer_load vector
    // length) together with tensor distribution. The tensor distribution should be able to
    // override this.
    static constexpr index_t kAlignmentQ = kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK = kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV = kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    static constexpr index_t kAlignmentO = kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
    static constexpr index_t kAlignmentKGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
    static constexpr index_t kAlignmentVGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
    static constexpr const char* name = "ks_vr";

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }
    template <typename QDramBlockWindowTmp,
              typename QTDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename KTDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename OGradTDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename DDramBlockWindowTmp,
              typename QGradDramBlockWindowTmp,
              typename BiasGradDramBlockWindowTmp,
              typename PositionEncoding>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
const
QTDramBlockWindowTmp
&
qt_dram_block_window_tmp
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
const
KTDramBlockWindowTmp
&
/*kt_dram_block_window_tmp*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
const
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
const
OGradDramBlockWindowTmp
&
do_dram_block_window_tmp
,
const
OGradTDramBlockWindowTmp
&
dot_dram_block_window_tmp
,
const
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
const
DDramBlockWindowTmp
&
d_dram_block_window_tmp
,
const
QGradDramBlockWindowTmp
&
dq_dram_block_window_tmp
,
const
BiasGradDramBlockWindowTmp
&
dbias_dram_block_window_tmp
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
,
#endif
float
rp_undrop
,
float
scale_rp_undrop
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
LSEDataType
,
remove_cvref_t
<
typename
LSEDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QGradDataType
,
remove_cvref_t
<
typename
QGradDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kQKHeaddim
==
QTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
OGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kVHeaddim
==
OGradTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
LSEDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
DDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
QGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
        // Q tile in LDS
        QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        auto q_lds_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});

        // QT tile in LDS
        QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto qt_lds = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
        auto qt_lds_window =
            make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});

        // K tile in LDS
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<KDataType*>(smem_ptr),
            Policy::template MakeKLdsBlockDescriptor<Problem>());
        auto k_lds_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});

        // KT tile in LDS
        auto kt_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<KDataType*>(smem_ptr),
            Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
        auto kt_lds_window =
            make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});

        // OGrad tile in LDS
        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto do_lds = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
        auto do_lds_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});

        // OGradT tile in LDS
        OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto dot_lds = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
        auto dot_lds_window =
            make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});

        // SGrad tile in LDS
        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto ds_lds = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        // BiasT/BiasGradT tile in LDS, use the same size and layout
        BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto biast_lds = make_tensor_view<address_space_enum::lds>(
            biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
        auto biast_lds_shuffle_window =
            make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto dbiast_lds_shuffle_window =
            make_tile_window(biast_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());
        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");
        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
        constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
        auto v_dram_window = make_tile_window(
            v_dram_block_window_tmp.get_bottom_tensor_view(),
            v_dram_block_window_tmp.get_window_lengths(),
            v_dram_block_window_tmp.get_window_origin(),
            Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
        auto v = load_tile(v_dram_window); // persistent V register tile

        using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
        using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
        using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());

        // init VGrad & KGrad
        auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
        auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
        clear_tile(dv_acc);
        clear_tile(dk_acc);

        auto k_dram_window =
            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                             k_dram_block_window_tmp.get_window_lengths(),
                             k_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load
        __builtin_amdgcn_sched_barrier(0);

        const auto k_origin = k_dram_window.get_window_origin();
        const auto [seqlen_q_start, seqlen_q_end] =
            mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
        const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);

        // check for early exit if masked and there is no work to do.
        if constexpr(FmhaMask::IsMasking)
        {
            if(num_total_loop <= 0)
            {
                // Note: dk_acc & dv_acc have already been cleared, so return them.
                // Note: v is loaded but without a fence; ignore it.
                return ck_tile::make_tuple(dk_acc, dv_acc);
            }
        }

        auto k_block_tile = load_tile(k_dram_window);
        store_tile(k_lds_window, k_block_tile); // persistent K in LDS
        auto q_dram_block_window =
            make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                             q_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});
        auto qt_dram_block_window =
            make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
                             qt_dram_block_window_tmp.get_window_lengths(),
                             {0, seqlen_q_start});
        auto do_dram_block_window =
            make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                             do_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});
        auto dot_dram_block_window =
            make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
                             dot_dram_block_window_tmp.get_window_lengths(),
                             {0, seqlen_q_start});
        auto dq_dram_block_window =
            make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});
        auto lse_dram_block_window =
            make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
                             lse_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start});
        auto d_dram_block_window =
            make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
                             d_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start});

        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_block_window =
            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                             bias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
        const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
        auto dbias_dram_block_window =
            make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                             dbias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N

        auto qt_dram_window =
            make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
                             qt_dram_block_window.get_window_lengths(),
                             qt_dram_block_window.get_window_origin(),
                             Policy::template MakeQTDramTileDistribution<Problem>());
        auto dot_dram_window =
            make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
                             dot_dram_block_window.get_window_lengths(),
                             dot_dram_block_window.get_window_origin(),
                             Policy::template MakeOGradTDramTileDistribution<Problem>());
        auto lse_dram_window = make_tile_window(
            lse_dram_block_window.get_bottom_tensor_view(),
            lse_dram_block_window.get_window_lengths(),
            lse_dram_block_window.get_window_origin(),
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        auto d_dram_window = make_tile_window(
            d_dram_block_window.get_bottom_tensor_view(),
            d_dram_block_window.get_window_lengths(),
            d_dram_block_window.get_window_origin(),
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        auto bias_dram_window =
            make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
                             bias_dram_block_window.get_window_lengths(),
                             bias_dram_block_window.get_window_origin(),
                             Policy::template MakeBiasTileDistribution<Problem>());
        auto biast_lds_window =
            make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
                             biast_lds_shuffle_window.get_window_lengths(),
                             biast_lds_shuffle_window.get_window_origin(),
                             Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
        auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
            randval_dram_block_window_tmp, seqlen_q_start);

        index_t i_total_loops      = 0;
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kM0 / kK1;
        constexpr index_t k2_loops = kVHeaddim / kK2;
        constexpr index_t k3_loops = kM0 / kK3;
        constexpr index_t k4_loops = kN0 / kK4;
        do
        {
            auto q_dram_window =
                make_tile_window(q_dram_block_window.get_bottom_tensor_view(),
                                 q_dram_block_window.get_window_lengths(),
                                 q_dram_block_window.get_window_origin(),
                                 Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for load
            auto do_dram_window =
                make_tile_window(do_dram_block_window.get_bottom_tensor_view(),
                                 do_dram_block_window.get_window_lengths(),
                                 do_dram_block_window.get_window_origin(),
                                 Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile window for load

            // STAGE 1, Q@K Gemm0
            auto st_acc       = SPTBlockTileType{};
            auto q_block_tile = load_tile(q_dram_window);
            {
                move_tile_window(q_dram_window, {0, kK0});
                clear_tile(st_acc);                      // Initialize S^T
                store_tile(q_lds_window, q_block_tile);  // LDS write 0
                q_block_tile = load_tile(q_dram_window); // global read 1
            }

            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(0); // prevent messing up the order of global loads
            }
            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(0); // prevent messing up the order of global loads
            }

            if constexpr(k0_loops > 2)
            {
                static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
                    block_sync_lds();
                    gemm_0(st_acc,
                           q_lds_window,
                           get_slice_tile(k_lds_window,
                                          sequence<0, i_k0 * kK0>{},
                                          sequence<kN0, (i_k0 + 1) * kK0>{}));
                    block_sync_lds();
                    move_tile_window(q_dram_window, {0, kK0});
                    store_tile(q_lds_window, q_block_tile);  // LDS write i + 1
                    q_block_tile = load_tile(q_dram_window); // global read i + 2
                });
            }

            const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
            { // tail
                block_sync_lds();
                gemm_0(st_acc,
                       q_lds_window,
                       get_slice_tile(k_lds_window,
                                      sequence<0, (k0_loops - 2) * kK0>{},
                                      sequence<kN0, (k0_loops - 1) * kK0>{}));
                block_sync_lds();
                store_tile(q_lds_window, q_block_tile);
                block_sync_lds();
                gemm_0(st_acc,
                       q_lds_window,
                       get_slice_tile(k_lds_window,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }
            // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                block_sync_lds();
                auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
                    Policy::template MakeShuffledBiasTileDistribution<Problem>());
                shuffle_tile(bias_shuffle_tmp, bias_tile);
                store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
                block_sync_lds();
                auto biast_tile = load_tile(biast_lds_window);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x = raw_scale * x + type_convert<AccDataType>(y);
#else
                        x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
                    },
                    st_acc,
                    biast_tile);
                move_tile_window(bias_dram_window, {kM0, 0});
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                const auto q_origin     = q_dram_block_window.get_window_origin();
                constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
                sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        st_acc(i_j_idx) *= raw_scale;
#else
                        st_acc(i_j_idx) *= scale;
#endif
                        position_encoding.update(st_acc(i_j_idx), row, col);
                    });
                });
            }
            else
            {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
            }

            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                const auto q_origin = q_dram_block_window.get_window_origin();
                bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                           k_origin.at(number<0>{}),
                                                           number<kM0>{},
                                                           number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        return mask.IsOutOfBound(row, col);
                    });
                }
            }
            const auto lse = load_tile(lse_dram_window);

            static const auto get_validated_lse = [](LSEDataType raw_lse) {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
                {
                    return raw_lse == -numeric<LSEDataType>::infinity()
                               ? type_convert<LSEDataType>(0.f)
                               : raw_lse;
                }
                else
                {
                    return raw_lse;
                }
            };

            auto pt                 = SPTBlockTileType{};
            constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
            sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
                sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
                    }
                    else
                    {
                        pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
                    }
#else
                    pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
                });
            });
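            // note on the two paths above: the forward softmax is replayed from the
            // stored LSE as P = exp(scale*S + bias - lse). On the fast-exp2 path,
            // `scale` is expected to already carry a log2(e) factor, so the per-row
            // term row_lse = log2e * lse is folded in once and each element costs a
            // single exp2.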
            auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
                Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
            block_sync_lds();
            {
                shuffle_tile(dot_shuffle_tmp, dot_prefetch);
                store_tile(dot_lds_window, dot_shuffle_tmp); // store the prefetch
            }
            move_tile_window(dot_dram_window, {0, kK1});

            if constexpr(kHasDropout)
            {
                dropout.Run<decltype(gemm_0), RandValOutputDataType>(
                    seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
            }
            // STAGE 3, P^T@OGrad^T Gemm1
            const auto pt_gemm = [&]() {
                if constexpr(kHasDropout)
                {
                    return tile_elementwise_in(
                        [](const auto& x) {
                            return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
                        },
                        pt);
                }
                else
                {
                    return cast_tile<GemmDataType>(pt);
                }
            }();

            if constexpr(k1_loops > 1)
            {
                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                    const auto dot = load_tile(dot_dram_window); // load next OGrad^T
                    block_sync_lds();
                    gemm_1(dv_acc,
                           get_slice_tile(pt_gemm,
                                          sequence<i_k1 * kK1, 0>{},
                                          sequence<(i_k1 + 1) * kK1, kN0>{}),
                           dot_lds_window);
                    block_sync_lds();
                    shuffle_tile(dot_shuffle_tmp, dot);
                    store_tile(dot_lds_window, dot_shuffle_tmp); // store the prefetch
                    move_tile_window(dot_dram_window, {0, kK1});
                });
            }
            auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
            // tail
            {
                block_sync_lds();
                gemm_1(dv_acc,
                       get_slice_tile(pt_gemm,
                                      sequence<(k1_loops - 1) * kK1, 0>{},
                                      sequence<kM0, kN0>{}),
                       dot_lds_window);
                block_sync_lds();
            }
            // STAGE 4, OGrad@V Gemm2
            auto dpt_acc = SPGradTBlockTileType{};
            {
                move_tile_window(do_dram_window, {0, kK2});
                clear_tile(dpt_acc);                       // Initialize PGrad^T
                store_tile(do_lds_window, do_block_tile);  // LDS write 0
                do_block_tile = load_tile(do_dram_window); // global read 1
            }

            if constexpr(k2_loops > 2)
            {
                static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
                    block_sync_lds();
                    gemm_2(dpt_acc,
                           do_lds_window,
                           get_slice_tile(v,
                                          sequence<0, i_k2 * kK2>{},
                                          sequence<kN0, (i_k2 + 1) * kK2>{}));
                    block_sync_lds();
                    move_tile_window(do_dram_window, {0, kK2});
                    store_tile(do_lds_window, do_block_tile);  // LDS write i + 1
                    do_block_tile = load_tile(do_dram_window); // global read i + 2
                });
            }
            const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
            { // tail
                block_sync_lds();
                gemm_2(dpt_acc,
                       do_lds_window,
                       get_slice_tile(v,
                                      sequence<0, (k2_loops - 2) * kK2>{},
                                      sequence<kN0, (k2_loops - 1) * kK2>{}));
                block_sync_lds();
                store_tile(do_lds_window, do_block_tile);
                block_sync_lds();
                gemm_2(dpt_acc,
                       do_lds_window,
                       get_slice_tile(v,
                                      sequence<0, (k2_loops - 1) * kK2>{},
                                      sequence<kN0, k2_loops * kK2>{}));
            }
            // STAGE 5, P^T(PGrad^T - D)
            const auto d             = load_tile(d_dram_window);
            auto dst                 = SPGradTBlockTileType{};
            constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
            sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    bool undrop_flag       = pt[i_j_idx] >= 0;
                    dst(i_j_idx) =
                        pt[i_j_idx] * (!kHasDropout || undrop_flag
                                           ? (dpt_acc[i_j_idx] - d[i_idx])
                                           : d[i_idx]);
                });
            });
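            // STAGE 5 above forms SGrad^T = P^T o (PGrad^T - D) elementwise, where D is
            // the per-row sum of OGrad o O produced by the separate dot_do_o preamble;
            // with dropout, the keep/drop decision was encoded in the sign of pt
            // earlier, which is what the undrop_flag test recovers.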
            if constexpr(kHasBiasGrad)
            {
                const auto dbiast = [&]() {
                    if constexpr(kHasDropout)
                    {
                        return tile_elementwise_in(
                            [&rp_undrop](const auto& x) {
                                return type_convert<BiasGradDataType>(x * rp_undrop);
                            },
                            dst);
                    }
                    else
                    {
                        return cast_tile<BiasGradDataType>(dst);
                    }
                }();
                store_tile(biast_lds_shuffle_window, dbiast);
                block_sync_lds();
                auto dbiast_tile        = load_tile(dbiast_lds_shuffle_window);
                auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
                    Policy::template MakeBiasTileDistribution<Problem>());
                shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
                store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
                move_tile_window(dbias_dram_block_window, {kM0, 0});
            }
            // STAGE 6, SGrad^T@Q^T Gemm3
            auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
                Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
            block_sync_lds();
            {
                shuffle_tile(qt_shuffle_tmp, qt_prefetch);
                store_tile(qt_lds_window, qt_shuffle_tmp); // store the prefetch
            }
            move_tile_window(qt_dram_window, {0, kK3});

            const auto dst_gemm = cast_tile<GemmDataType>(dst);

            if constexpr(k3_loops > 1)
            {
                static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
                    const auto qt = load_tile(qt_dram_window); // load next Q^T
                    block_sync_lds();
                    gemm_3(dk_acc,
                           get_slice_tile(dst_gemm,
                                          sequence<i_k3 * kK3, 0>{},
                                          sequence<(i_k3 + 1) * kK3, kN0>{}),
                           qt_lds_window);
                    block_sync_lds();
                    shuffle_tile(qt_shuffle_tmp, qt);
                    store_tile(qt_lds_window, qt_shuffle_tmp); // store the prefetch
                    move_tile_window(qt_dram_window, {0, kK3});
                });
            }
            // tail
            {
                block_sync_lds();
                gemm_3(dk_acc,
                       get_slice_tile(dst_gemm,
                                      sequence<(k3_loops - 1) * kK3, 0>{},
                                      sequence<kM0, kN0>{}),
                       qt_lds_window);
                block_sync_lds();
            }
            // STAGE 7, SGrad@K^T Gemm4
            store_tile(ds_lds_window, dst_gemm);
            auto dq_acc = QGradBlockTileType{};
            clear_tile(dq_acc); // Initialize QGrad
            block_sync_lds();
            static_for<0, k4_loops, 1>{}([&](auto i_k4) {
                gemm_4(dq_acc,
                       get_slice_tile(ds_lds_window,
                                      sequence<0, i_k4 * kK4>{},
                                      sequence<kM0, (i_k4 + 1) * kK4>{}),
                       get_slice_tile(kt_lds_window,
                                      sequence<0, i_k4 * kK4>{},
                                      sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
            });

            // QGrad Scale
            if constexpr(kHasDropout)
            {
                tile_elementwise_inout(
                    [&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dq_acc);
            }
            else
            {
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
            }
            const auto dq = cast_tile<QGradDataType>(dq_acc);
            update_tile(dq_dram_block_window, dq);

            // move tile windows
            move_tile_window(q_dram_block_window, {kM0, 0});
            move_tile_window(dq_dram_block_window, {kM0, 0});
            move_tile_window(do_dram_block_window, {kM0, 0});
            move_tile_window(lse_dram_window, {kM0});
            move_tile_window(d_dram_window, {kM0});
        } while(++i_total_loops < num_total_loop);
        // KGrad Scale
        if constexpr(kHasDropout)
        {
            tile_elementwise_inout(
                [&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dk_acc);
        }
        else
        {
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
        }

        // VGrad Scale
        if constexpr(kHasDropout)
        {
            tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
        }

        return ck_tile::make_tuple(dk_acc, dv_acc);
    }
};

} // namespace ck_tile
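Taken together, STAGE 1 through STAGE 7 implement the standard FlashAttention-style backward pass. The following host-side scalar sketch is illustrative only (toy sizes, row-major layout, no bias/mask/dropout, and all variable names hypothetical); it reproduces the same math the pipeline above tiles across LDS and registers:

// bwd_reference.cpp -- illustrative scalar reference, not part of this commit
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 4, N = 4, D = 8; // seqlen_q, seqlen_k, head dim (toy sizes)
    const float scale = 1.0f / std::sqrt(float(D));
    std::vector<float> Q(M * D), K(N * D), V(N * D), dO(M * D), O(M * D);
    for(int i = 0; i < M * D; ++i) { Q[i] = 0.01f * i; dO[i] = 0.02f * (i % 7); }
    for(int i = 0; i < N * D; ++i) { K[i] = 0.015f * i; V[i] = 0.01f * (i % 5); }

    std::vector<float> P(M * N), lse(M), Drow(M), dP(M * N), dS(M * N);
    std::vector<float> dQ(M * D, 0), dK(N * D, 0), dV(N * D, 0);

    // forward replay: S = scale*Q@K^T, then P from the stored lse (what STAGE 1-2 recompute)
    for(int i = 0; i < M; ++i)
    {
        std::vector<float> s(N);
        float m = -1e30f;
        for(int j = 0; j < N; ++j)
        {
            float acc = 0;
            for(int d = 0; d < D; ++d) acc += Q[i * D + d] * K[j * D + d];
            s[j] = scale * acc;
            m    = std::max(m, s[j]);
        }
        float sum = 0;
        for(int j = 0; j < N; ++j) sum += std::exp(s[j] - m);
        lse[i] = m + std::log(sum);
        for(int j = 0; j < N; ++j) P[i * N + j] = std::exp(s[j] - lse[i]);
        for(int d = 0; d < D; ++d)
        {
            float o = 0;
            for(int j = 0; j < N; ++j) o += P[i * N + j] * V[j * D + d];
            O[i * D + d] = o;
        }
    }
    // D = rowsum(dO o O) -- produced by the dot_do_o preamble in the kernel
    for(int i = 0; i < M; ++i)
        for(int d = 0; d < D; ++d) Drow[i] += dO[i * D + d] * O[i * D + d];
    // gemm_1: dV = P^T @ dO; gemm_2: dP = dO @ V^T
    for(int j = 0; j < N; ++j)
        for(int d = 0; d < D; ++d)
            for(int i = 0; i < M; ++i) dV[j * D + d] += P[i * N + j] * dO[i * D + d];
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N; ++j)
            for(int d = 0; d < D; ++d) dP[i * N + j] += dO[i * D + d] * V[j * D + d];
    // STAGE 5: dS = P o (dP - D)
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N; ++j) dS[i * N + j] = P[i * N + j] * (dP[i * N + j] - Drow[i]);
    // gemm_3/gemm_4 plus the final scaling: dK = scale*dS^T@Q, dQ = scale*dS@K
    for(int j = 0; j < N; ++j)
        for(int d = 0; d < D; ++d)
            for(int i = 0; i < M; ++i) dK[j * D + d] += scale * dS[i * N + j] * Q[i * D + d];
    for(int i = 0; i < M; ++i)
        for(int d = 0; d < D; ++d)
            for(int j = 0; j < N; ++j) dQ[i * D + d] += scale * dS[i * N + j] * K[j * D + d];

    std::printf("dQ[0]=%f dK[0]=%f dV[0]=%f\n", dQ[0], dK[0], dV[0]);
    return 0;
}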
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
deleted
100644 → 0
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {

// This pipeline holds v in registers and k in LDS.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
deleted
100644 → 0
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {

// This pipeline holds v in registers and q, k & do in LDS.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ true,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ true,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
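Both deleted aliases differ only in which operands stay resident where; a consumer branches on the flags at compile time. A minimal sketch of that mechanism (the MyPipeline wrapper is hypothetical; the alias is the one deleted above):

template <typename Problem,
          typename Policy = ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
struct MyPipeline // hypothetical consumer, illustrative only
{
    static constexpr void describe()
    {
        if constexpr(Policy::VLoadOnce)
        {
            // v is loaded once into registers and reused for every k2 slice
        }
        if constexpr(Policy::KLoadOnce)
        {
            // k stays resident in LDS across the whole q loop
        }
    }
};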
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
f84e2020
...
...
@@ -11,6 +11,8 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -18,60 +20,215 @@
namespace ck_tile {

template <bool QLoadOnce_,
          bool QTLoadOnce_,
          bool KLoadOnce_,
          bool KTLoadOnce_,
          bool VLoadOnce_,
          bool OGradLoadOnce_,
          bool OGradTLoadOnce_>
struct BlockFmhaBwdPipelineDefaultPolicy
{
    static constexpr bool QLoadOnce  = QLoadOnce_;  // whether q loads its whole block length (qkhdim) into LDS at once
    static constexpr bool QTLoadOnce = QTLoadOnce_; // whether q^t loads its whole block length (qkhdim) into LDS at once
    static constexpr bool KLoadOnce  = KLoadOnce_;  // whether k loads its whole block length (qkhdim) into LDS at once
    static constexpr bool KTLoadOnce = KTLoadOnce_; // whether k^t loads its whole block length (qkhdim) into LDS at once
    static constexpr bool VLoadOnce  = VLoadOnce_;  // whether v loads its whole block length (vhdim) into VGPRs at once
    static constexpr bool OGradLoadOnce  = OGradLoadOnce_;  // whether do loads its whole block length (vhdim) into LDS at once
    static constexpr bool OGradTLoadOnce = OGradTLoadOnce_; // whether do^t loads its whole block length (vhdim) into LDS at once
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
+   {
+       using BlockGemmProblem =
+           BlockGemmPipelineProblem<typename Problem::QDataType,
+                                    typename Problem::KDataType,
+                                    typename Problem::AccDataType,
+                                    Problem::kBlockSize,
+                                    TileGemmShape<Problem::BlockFmhaShape::kM0,
+                                                  Problem::BlockFmhaShape::kN0,
+                                                  Problem::BlockFmhaShape::kK0>>;
+       using WarpGemm = WarpGemmMfmaDispatcher<
+           typename Problem::QDataType,
+           typename Problem::KDataType,
+           typename Problem::AccDataType,
+           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
+           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
+           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
+           false,
+           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
+       using BlockGemmPolicy =
+           BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
+                                               typename Problem::KDataType,
+                                               typename Problem::AccDataType,
+                                               typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                               WarpGemm>;
+       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+   }
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
+   {
+       using BlockGemmProblem =
+           BlockGemmPipelineProblem<typename Problem::GemmDataType,
+                                    typename Problem::OGradDataType,
+                                    typename Problem::AccDataType,
+                                    Problem::kBlockSize,
+                                    TileGemmShape<Problem::BlockFmhaShape::kN0,
+                                                  Problem::BlockFmhaShape::kVHeaddim,
+                                                  Problem::BlockFmhaShape::kK1>>;
+       using WarpGemm = WarpGemmMfmaDispatcher<
+           typename Problem::GemmDataType,
+           typename Problem::OGradDataType,
+           typename Problem::AccDataType,
+           Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
+           Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
+           Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
+           true>;
+       using BlockGemmPolicy =
+           BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
+                                               typename Problem::OGradDataType,
+                                               typename Problem::AccDataType,
+                                               typename Problem::BlockFmhaShape::Gemm1BlockWarps,
+                                               WarpGemm>;
+       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+   }
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
+   {
+       using BlockGemmProblem =
+           BlockGemmPipelineProblem<typename Problem::OGradDataType,
+                                    typename Problem::VDataType,
+                                    typename Problem::AccDataType,
+                                    Problem::kBlockSize,
+                                    TileGemmShape<Problem::BlockFmhaShape::kM0,
+                                                  Problem::BlockFmhaShape::kN0,
+                                                  Problem::BlockFmhaShape::kK2>>;
+       using WarpGemm = WarpGemmMfmaDispatcher<
+           typename Problem::OGradDataType,
+           typename Problem::VDataType,
+           typename Problem::AccDataType,
+           Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
+           Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
+           Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
+           false,
+           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
+       using BlockGemmPolicy =
+           BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
+                                               typename Problem::VDataType,
+                                               typename Problem::AccDataType,
+                                               typename Problem::BlockFmhaShape::Gemm2BlockWarps,
+                                               WarpGemm>;
+       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+   }
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
+   {
+       using BlockGemmProblem =
+           BlockGemmPipelineProblem<typename Problem::GemmDataType,
+                                    typename Problem::QDataType,
+                                    typename Problem::AccDataType,
+                                    Problem::kBlockSize,
+                                    TileGemmShape<Problem::BlockFmhaShape::kN0,
+                                                  Problem::BlockFmhaShape::kQKHeaddim,
+                                                  Problem::BlockFmhaShape::kK3>>;
+       using WarpGemm = WarpGemmMfmaDispatcher<
+           typename Problem::GemmDataType,
+           typename Problem::QDataType,
+           typename Problem::AccDataType,
+           Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
+           Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
+           Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
+           true>;
+       using BlockGemmPolicy =
+           BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
+                                               typename Problem::QDataType,
+                                               typename Problem::AccDataType,
+                                               typename Problem::BlockFmhaShape::Gemm3BlockWarps,
+                                               WarpGemm>;
+       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+   }
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
+   {
+       using BlockGemmProblem =
+           BlockGemmPipelineProblem<typename Problem::GemmDataType,
+                                    typename Problem::KDataType,
+                                    typename Problem::AccDataType,
+                                    Problem::kBlockSize,
+                                    TileGemmShape<Problem::BlockFmhaShape::kM0,
+                                                  Problem::BlockFmhaShape::kQKHeaddim,
+                                                  Problem::BlockFmhaShape::kK4>>;
+       using WarpGemm = WarpGemmMfmaDispatcher<
+           typename Problem::GemmDataType,
+           typename Problem::KDataType,
+           typename Problem::AccDataType,
+           Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
+           Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
+           Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
+           false>;
+       using BlockGemmPolicy =
+           BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
+                                               typename Problem::KDataType,
+                                               typename Problem::AccDataType,
+                                               typename Problem::BlockFmhaShape::Gemm4BlockWarps,
+                                               WarpGemm>;
+       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+   }
    // these are for global load
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
    {
        using QDataType = remove_cvref_t<typename Problem::QDataType>;
-       return 16 / sizeof(QDataType);
+       constexpr index_t kBlockSize   = Problem::kBlockSize;
+       constexpr index_t kMNPerBlock  = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kKPerBlock   = Problem::BlockFmhaShape::kK0;
+       constexpr index_t kMaxVecLoad  = 16 / sizeof(QDataType);
+       constexpr index_t kMinVecLoad  = 4 / sizeof(QDataType);
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t kVecLoad     = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+                                            ? kMaxVecLoad
+                                            : (total_pixels / kMinVecLoad);
+       return kVecLoad;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
    {
        using KDataType = remove_cvref_t<typename Problem::KDataType>;
-       return 16 / sizeof(KDataType);
+       constexpr index_t kBlockSize   = Problem::kBlockSize;
+       constexpr index_t kMNPerBlock  = Problem::BlockFmhaShape::kN0;
+       constexpr index_t kKPerBlock   = Problem::BlockFmhaShape::kK0;
+       constexpr index_t kMaxVecLoad  = 16 / sizeof(KDataType);
+       constexpr index_t kMinVecLoad  = 4 / sizeof(KDataType);
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t kVecLoad     = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+                                            ? kMaxVecLoad
+                                            : (total_pixels / kMinVecLoad);
+       return kVecLoad;
    }
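    // Worked example for the kVecLoad selection above (illustrative numbers only,
    // not taken from this commit): with 2-byte elements, kMNPerBlock = 64,
    // kKPerBlock = 32 and kBlockSize = 256, total_pixels = 64*32/256 = 8,
    // kMaxVecLoad = 8 and kMinVecLoad = 2; since 8/8 = 1 < 2, the fallback
    // kVecLoad = total_pixels / kMinVecLoad = 4 is chosen.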
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
    {
        if constexpr(VLoadOnce)
        {
            using BlockGemm       = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
            constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
            using WG              = remove_cvref_t<decltype(config.template at<0>())>;
            return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
        }
        else
        {
            using VDataType = remove_cvref_t<typename Problem::VDataType>;
-           return 16 / sizeof(VDataType);
+           constexpr index_t kBlockSize   = Problem::kBlockSize;
+           constexpr index_t kMNPerBlock  = Problem::BlockFmhaShape::kN0;
+           constexpr index_t kKPerBlock   = Problem::BlockFmhaShape::kK2;
+           constexpr index_t kMaxVecLoad  = 16 / sizeof(VDataType);
+           constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
+           return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
        }
    }
    template <typename Problem>
...
...
@@ -85,19 +242,38 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad()
    {
        using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
-       return 16 / sizeof(OGradDataType);
+       constexpr index_t kBlockSize   = Problem::kBlockSize;
+       constexpr index_t kMNPerBlock  = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kKPerBlock   = Problem::BlockFmhaShape::kK2;
+       constexpr index_t kMaxVecLoad  = 16 / sizeof(OGradDataType);
+       constexpr index_t kMinVecLoad  = 4 / sizeof(OGradDataType);
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t kVecLoad     = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+                                            ? kMaxVecLoad
+                                            : (total_pixels / kMinVecLoad);
+       return kVecLoad;
    }
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad()
+   CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
    {
-       using BlockGemm       = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
-       constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-       using WG              = remove_cvref_t<decltype(config.template at<0>())>;
-       using CWarpDstr       = typename WG::CWarpDstr;
-       constexpr auto vec =
-           CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
-       return vec;
+       using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
+       constexpr index_t kBlockSize   = Problem::kBlockSize;
+       constexpr index_t kMPerBlock   = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kNPerBlock   = Problem::BlockFmhaShape::kN0;
+       constexpr index_t kMaxVecLoad  = 16 / sizeof(BiasDataType);
+       constexpr index_t kMinVecLoad  = 4 / sizeof(BiasDataType);
+       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
+       constexpr index_t kVecLoad     = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+                                            ? kMaxVecLoad
+                                            : (total_pixels / kMinVecLoad);
+       return kVecLoad;
    }
    template <typename Problem>
...
...
@@ -128,60 +304,35 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(QTLoadOnce)
-               return Problem::BlockFmhaShape::kM0;
-           else
-               return Problem::BlockFmhaShape::kK3;
-       }();
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-       // TODO: not correct!
-       if constexpr(total_pixels > 4)
-           return 4;
-       else
-           return 2;
+       return total_pixels / GetAlignmentQ<Problem>();
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(KTLoadOnce)
-               return Problem::BlockFmhaShape::kN0;
-           else
-               return Problem::BlockFmhaShape::kK4;
-       }();
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-       // TODO: not correct!
-       if constexpr(total_pixels > 4)
-           return 4;
-       else
-           return 2;
+       return total_pixels / GetAlignmentK<Problem>();
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(OGradTLoadOnce)
-               return Problem::BlockFmhaShape::kM0;
-           else
-               return Problem::BlockFmhaShape::kK1;
-       }();
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-       // TODO: not correct!
-       if constexpr(total_pixels > 4)
-           return 4;
-       else
-           return 2;
+       return total_pixels / GetAlignmentOGrad<Problem>();
    }
    template <typename Problem>
...
...
@@ -193,1151 +344,1577 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
-       // TODO: not correct!
-       if constexpr(total_pixels > 32)
-           return 8;
-       else
-           return 4;
+       return total_pixels / GetAlignmentBias<Problem>();
    }
-   // these are for lds
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
+   CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGradAcc()
    {
-       // TODO: this is for 3d layout
-       using QDataType = remove_cvref_t<typename Problem::QDataType>;
-       return 16 / sizeof(QDataType);
+       using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
+       return 16 / sizeof(AccDataType);
    }
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
+   CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGrad()
    {
-       // TODO: this is for 3d layout
-       using KDataType = remove_cvref_t<typename Problem::KDataType>;
-       return 16 / sizeof(KDataType);
+       return GetAlignmentPostQGradAcc<Problem>();
    }
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
+   CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
    {
-       // TODO: this is for 3d layout
-       using VDataType = remove_cvref_t<typename Problem::VDataType>;
-       return 16 / sizeof(VDataType);
-   }
+       constexpr index_t kBlockSize = Problem::kBlockSize;
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
-   {
-       // TODO: this is for 3d layout
-       using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
-       return 16 / sizeof(BiasDataType);
-   }
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad()
-   {
-       // TODO: this is for 3d layout
-       using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
-       return 16 / sizeof(OGradDataType);
+       constexpr index_t K1 = GetAlignmentK<Problem>();
+       constexpr index_t K0 = kKPerBlock / K1;
+       constexpr index_t N1 = get_warp_size() / K0;
+       constexpr index_t N0 = kBlockSize / get_warp_size();
+       constexpr index_t N2 = kNPerBlock / (N1 * N0);
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
+                                      sequence<1, 2>,
+                                      sequence<2, 1>>{});
    }
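    // Reading the new MakeKDramTileDistribution above with illustrative numbers
    // (2-byte K elements, kN0 = 128, kK0 = 32, kBlockSize = 256, warp size 64):
    // K1 = GetAlignmentK = 8, so K0 = 4; N1 = 64/4 = 16 lanes per warp along N,
    // N0 = 256/64 = 4 warps, and N2 = 128/(16*4) = 2 per-lane repeats. Each lane
    // then issues K1-wide vector loads, matching what the alignment helpers compute.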
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad()
+   CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution()
    {
-       // TODO: this is for 3d layout
-       using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
-       return 16 / sizeof(GemmDataType);
-   }
+       constexpr index_t kBlockSize = Problem::kBlockSize;
-   template <typename Problem, typename BlockGemm>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeVInRegDramTileDistribution()
-   {
        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
-       constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+       constexpr index_t K1 = GetAlignmentV<Problem>();
+       constexpr index_t K0 = kKPerBlock / K1;
+       constexpr index_t N2 = get_warp_size() / K0;
+       constexpr index_t N1 = kBlockSize / get_warp_size();
+       constexpr index_t N0 = kNPerBlock / (N2 * N1);
-       using WG = remove_cvref_t<decltype(config.template at<0>())>;
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<1>, sequence<2, 0>>,
+                                      sequence<1, 2>,
+                                      sequence<0, 1>>{});
+   }
-       constexpr index_t MWarp = config.template at<1>();
-       constexpr index_t NWarp = config.template at<2>();
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
+   {
+       constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
-       constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
-       constexpr auto v_block_outer_dstr_encoding =
-           tile_distribution_encoding<sequence<MWarp>,
-                                      tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
-                                      tuple<sequence<0, 1>>,
-                                      tuple<sequence<0, 1>>,
-                                      sequence<1, 2>,
-                                      sequence<0, 0>>{};
-       constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-           v_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
-       constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
+       constexpr index_t K1 = GetAlignmentQ<Problem>();
+       constexpr index_t K0 = kKPerBlock / K1;
+       constexpr index_t M1 = get_warp_size() / K0;
+       constexpr index_t M0 = kBlockSize / get_warp_size();
+       constexpr index_t M2 = kMPerBlock / (M1 * M0);
-       return v_block_dstr;
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
+                                      sequence<1, 2>,
+                                      sequence<2, 1>>{});
    }
    // 3d + padding
-   template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution()
    {
-       constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
-           make_tuple(number<KPerBlock / KPack>{}, number<MNPerBlock>{}, number<KPack>{}),
-           make_tuple(number<(MNPerBlock + 1) * KPack>{}, number<KPack>{}, number<1>{}),
-           number<8>{},
-           number<1>{});
-       constexpr auto x_lds_block_desc = transform_tensor_descriptor(
-           x_lds_block_desc_0,
-           make_tuple(make_pass_through_transform(MNPerBlock),
-                      make_merge_transform(make_tuple(KPerBlock / KPack, KPack))),
-           make_tuple(sequence<1>{}, sequence<0, 2>{}),
-           make_tuple(sequence<0>{}, sequence<1>{}));
-       return x_lds_block_desc;
-   }
+       constexpr index_t kBlockSize = Problem::kBlockSize;
-   // 3d + padding
-   template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptorAsXT()
-   {
-       constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
-           make_tuple(number<KPerBlock / KPack>{}, number<MNPerBlock>{}, number<KPack>{}),
-           make_tuple(number<(MNPerBlock + 1) * KPack>{}, number<KPack>{}, number<1>{}),
-           number<8>{},
-           number<1>{});
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
-       constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
-           x_lds_block_desc_0,
-           make_tuple(make_pass_through_transform(MNPerBlock),
-                      make_merge_transform(make_tuple(KPerBlock / KPack, KPack))),
-           make_tuple(sequence<1>{}, sequence<0, 2>{}),
-           make_tuple(sequence<1>{}, sequence<0>{}));
+       constexpr index_t K1 = GetAlignmentOGrad<Problem>();
+       constexpr index_t K0 = kKPerBlock / K1;
+       constexpr index_t M1 = get_warp_size() / K0;
+       constexpr index_t M0 = kBlockSize / get_warp_size();
+       constexpr index_t M2 = kMPerBlock / (M1 * M0);
-       return xt_lds_block_desc;
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
+                                      sequence<1, 2>,
+                                      sequence<2, 1>>{});
    }
-   template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, index_t PixelsPerRow>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
+   template <typename Problem, typename BlockGemm>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
    {
-       static_assert(PixelsPerRow % KPack == 0);
-       constexpr index_t NPerRow = PixelsPerRow / KPack;
-       static_assert(MNPerBlock % NPerRow == 0);
-       static_assert(KPerBlock % KPack == 0);
-       constexpr auto xt_lds_block_desc_0 = make_naive_tensor_descriptor(
-           make_tuple(number<KPerBlock / KPack>{},
-                      number<MNPerBlock / NPerRow>{},
-                      number<NPerRow>{},
-                      number<KPack>{}),
-           make_tuple(number<(MNPerBlock / NPerRow) * (PixelsPerRow + KPack)>{},
-                      number<PixelsPerRow + KPack>{},
-                      number<KPack>{},
-                      number<1>{}),
-           number<KPack>{},
-           number<1>{});
+       constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+       constexpr index_t MWarp = config.template at<1>();
+       constexpr index_t NWarp = config.template at<2>();
-       constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
-           xt_lds_block_desc_0,
-           make_tuple(
-               make_merge_transform(make_tuple(number<MNPerBlock / NPerRow>{}, number<NPerRow>{})),
-               make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
-           make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
-           make_tuple(sequence<0>{}, sequence<1>{}));
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       return xt_lds_block_desc;
-   }
+       // Duplicate dimension
+       constexpr index_t N0 = NWarp;
+       constexpr index_t N1 =
+           (get_warp_size() / kMPerBlock) > 1 ? (get_warp_size() / kMPerBlock) : 1;
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
-   {
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(QLoadOnce)
-               return Problem::BlockFmhaShape::kQKHeaddim;
-           else
-               return Problem::BlockFmhaShape::kK0;
-       }();
-       constexpr index_t kKPack = GetSmemKPackQ<Problem>();
+       constexpr index_t M0 = MWarp;
+       constexpr index_t M1 =
+           (get_warp_size() / kMPerBlock) > 1 ? kMPerBlock : get_warp_size();
+       constexpr index_t M2 =
+           (get_warp_size() / kMPerBlock) > 1 ? 1 : (kMPerBlock / get_warp_size());
-       return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<N0, N1>,
+                                      tuple<sequence<M0, M1, M2>>,
+                                      tuple<sequence<0, 1>, sequence<0, 1>>,
+                                      tuple<sequence<0, 0>, sequence<1, 1>>,
+                                      sequence<1>,
+                                      sequence<2>>{});
    }
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptorAsQT()
+   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
    {
+       constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(QLoadOnce)
-               return Problem::BlockFmhaShape::kQKHeaddim;
-           else
-               return Problem::BlockFmhaShape::kK0;
-       }();
-       constexpr index_t kKPack = GetSmemKPackQ<Problem>();
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+       constexpr index_t N1 = GetAlignmentBias<Problem>();
+       constexpr index_t N0 = kNPerBlock / N1;
+       constexpr index_t M1 = get_warp_size() / N0;
+       constexpr index_t M0 = kBlockSize / get_warp_size();
+       constexpr index_t M2 = kMPerBlock / (M1 * M0);
-       return MakeXLdsBlockDescriptorAsXT<kMPerBlock, kKPerBlock, kKPack>();
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
+                                      sequence<1, 2>,
+                                      sequence<2, 1>>{});
    }
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
+   template <typename DataType, index_t MPerBlock, index_t KPerBlock>
+   CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution()
    {
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(KLoadOnce)
-               return Problem::BlockFmhaShape::kQKHeaddim;
-           else
-               return Problem::BlockFmhaShape::kK0;
-       }();
-       constexpr index_t kKPack = GetSmemKPackK<Problem>();
+       constexpr index_t K1 = 16 / sizeof(DataType);
+       constexpr index_t K0 = KPerBlock / K1;
+       constexpr index_t M2 = 1;
+       constexpr index_t M1 = get_warp_size();
+       constexpr index_t M0 = MPerBlock / M1;
-       return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                      tuple<sequence<1>, sequence<1>>,
+                                      tuple<sequence<0>, sequence<1>>,
+                                      sequence<1, 2, 2>,
+                                      sequence<2, 0, 1>>{});
    }
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptorAsKT()
+   CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution()
    {
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(KLoadOnce)
-               return Problem::BlockFmhaShape::kQKHeaddim;
-           else
-               return Problem::BlockFmhaShape::kK0;
-       }();
-       constexpr index_t kKPack = GetSmemKPackK<Problem>();
+       using ODataType = remove_cvref_t<typename Problem::ODataType>;
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kKPerBlock = Problem::kVHeaddim;
-       return MakeXLdsBlockDescriptorAsXT<kNPerBlock, kKPerBlock, kKPack>();
+       return MakePreXDramTileDistribution<ODataType, kBlockSize, kKPerBlock>();
    }
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
+   CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution()
    {
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
-       constexpr index_t kKPack = GetSmemKPackV<Problem>();
+       using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
-       return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kKPerBlock = Problem::kVHeaddim;
+       return MakePreXDramTileDistribution<OGradDataType, kBlockSize, kKPerBlock>();
    }
    template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor()
+   CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
    {
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(OGradLoadOnce)
-               return Problem::BlockFmhaShape::kVHeaddim;
-           else
-               return Problem::BlockFmhaShape::kK2;
-       }();
-       constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
+       using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
-       return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
-   }
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kMPerBlock = Problem::kM0;
+       constexpr index_t kKPerBlock = Problem::kQKHeaddim;
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptorAsOGradT()
-   {
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kKPerBlock = [&]() {
-           if constexpr(OGradLoadOnce)
-               return Problem::BlockFmhaShape::kVHeaddim;
-           else
-               return Problem::BlockFmhaShape::kK2;
-       }();
-       constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
+       constexpr index_t K1 = 16 / sizeof(AccDataType);
+       constexpr index_t K0 = kKPerBlock / K1;
+       constexpr index_t M2 = get_warp_size() / K0;
+       constexpr index_t M1 = kBlockSize / get_warp_size();
+       constexpr index_t M0 = kMPerBlock / (M1 * M2);
-       return MakeXLdsBlockDescriptorAsXT<kMPerBlock, kKPerBlock, kKPack>();
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                      tuple<sequence<2>, sequence<2, 3>>,
+                                      tuple<sequence<1>, sequence<2, 0>>,
+                                      sequence<1, 2, 3>,
+                                      sequence<0, 0, 1>>{});
    }
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor()
-{
-    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
-    constexpr index_t kKPack     = GetSmemKPackSGrad<Problem>();
-    return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
-}
-
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsBlockDescriptor()
-{
-    using QDataType = remove_cvref_t<typename Problem::QDataType>;
-    constexpr index_t Banks        = 32; // TODO: need change based on arch
-    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QDataType);
-    constexpr index_t kKPack       = GetSmemKPackQ<Problem>();
-    constexpr index_t kNPerBlock   = Problem::BlockFmhaShape::kQKHeaddim;
-    constexpr index_t kKPerBlock   = [&]() {
-        if constexpr(QTLoadOnce)
-            return Problem::BlockFmhaShape::kM0;
-        else
-            return Problem::BlockFmhaShape::kK3;
-    }();
-    return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution()
+{
+    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
+    constexpr index_t kBlockSize = Problem::kBlockSize;
+    constexpr index_t kMPerBlock = Problem::kM0;
+    constexpr index_t kKPerBlock = Problem::kQKHeaddim;
+    constexpr index_t K1 = 16 / sizeof(AccDataType);
+    constexpr index_t K0 = kKPerBlock / K1;
+    constexpr index_t M2 = get_warp_size() / K0;
+    constexpr index_t M1 = kBlockSize / get_warp_size();
+    constexpr index_t M0 = kMPerBlock / (M1 * M2);
+    return make_static_tile_distribution(
+        tile_distribution_encoding<sequence<>,
+                                   tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                   tuple<sequence<1>, sequence<1, 2>>,
+                                   tuple<sequence<1>, sequence<2, 0>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 1>>{});
+}
 // these are for lds
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsBlockDescriptor()
-{
-    using KDataType = remove_cvref_t<typename Problem::KDataType>;
-    constexpr index_t Banks        = 32; // TODO: need change based on arch
-    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
-    constexpr index_t kKPack       = GetSmemKPackK<Problem>();
-    constexpr index_t kNPerBlock   = Problem::BlockFmhaShape::kQKHeaddim;
-    constexpr index_t kKPerBlock   = [&]() {
-        if constexpr(KTLoadOnce)
-            return Problem::BlockFmhaShape::kN0;
-        else
-            return Problem::BlockFmhaShape::kK4;
-    }();
-    return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
+{
+    return GetAlignmentQ<Problem>();
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor()
-{
-    using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
-    constexpr index_t Banks        = 32; // TODO: need change based on arch
-    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType);
-    constexpr index_t kKPack       = GetSmemKPackOGrad<Problem>();
-    constexpr index_t kNPerBlock   = Problem::BlockFmhaShape::kVHeaddim;
-    constexpr index_t kKPerBlock   = [&]() {
-        if constexpr(OGradTLoadOnce)
-            return Problem::BlockFmhaShape::kM0;
-        else
-            return Problem::BlockFmhaShape::kK1;
-    }();
-    return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQT()
+{
+    return GetTransposedAlignmentQ<Problem>();
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor()
-{
-    using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
-    constexpr index_t Banks        = 32; // TODO: need change based on arch
-    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType);
-    constexpr index_t kKPack       = GetSmemKPackBias<Problem>();
-    constexpr index_t kMPerBlock   = Problem::BlockFmhaShape::kM0;
-    constexpr index_t kNPerBlock   = Problem::BlockFmhaShape::kN0;
-    static_assert(PixelsPerRow % kKPack == 0);
-    constexpr index_t NPerRow = PixelsPerRow / kKPack;
-    static_assert(kNPerBlock % NPerRow == 0);
-    static_assert(kMPerBlock % kKPack == 0);
-    constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor(
-        make_tuple(number<kMPerBlock / kKPack>{},
-                   number<kNPerBlock / NPerRow>{},
-                   number<NPerRow>{},
-                   number<kKPack>{}),
-        make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
-                   number<PixelsPerRow + kKPack>{},
-                   number<kKPack>{},
-                   number<1>{}),
-        number<kKPack>{},
-        number<1>{});
-    constexpr auto biast_lds_block_desc = transform_tensor_descriptor(
-        biast_lds_block_desc_0,
-        make_tuple(make_merge_transform(
-                       make_tuple(number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
-                   make_merge_transform(
-                       make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
-        make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
-        make_tuple(sequence<1>{}, sequence<0>{}));
-    return biast_lds_block_desc;
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
+{
+    return GetAlignmentK<Problem>();
+}
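The `PixelsPerRow + kKPack` row pitch in the removed descriptor above is the whole point of that layout: padding each packed row by one extra K-pack keeps successive rows from starting on the same LDS bank. A standalone arithmetic sketch of that property (plain C++; the 2-byte element size is an assumed example, not taken from this commit):

    constexpr int Banks = 32, BankBytes = 4; // 32 four-byte LDS banks, per the TODO above
    constexpr int ElemBytes = 2;             // assumed fp16/bf16 bias element
    constexpr int PixelsPerRow = Banks * BankBytes / ElemBytes; // 64: one full sweep of the banks
    constexpr int kKPack = 16 / ElemBytes;                      // 8-element packed access
    constexpr int RowPitchBytes = (PixelsPerRow + kKPack) * ElemBytes; // 144 bytes
    static_assert(RowPitchBytes % (Banks * BankBytes) != 0,
                  "padded rows start on a different bank than the previous row");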
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
-{
-    constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) *
-                                    MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
-    return smem_size_q;
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackKT()
+{
+    return GetTransposedAlignmentK<Problem>();
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQT()
-{
-    constexpr index_t smem_size_qt = [&]() {
-        if constexpr(QLoadOnce && !QTLoadOnce)
-            return 0;
-        else
-            return sizeof(typename Problem::QDataType) *
-                   MakeQTLdsBlockDescriptor<Problem>().get_element_space_size();
-    }();
-    return smem_size_qt;
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
+{
+    return GetAlignmentV<Problem>();
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
-{
-    constexpr index_t smem_size_k = sizeof(typename Problem::KDataType) *
-                                    MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
-    return smem_size_k;
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
+{
+    return GetAlignmentBias<Problem>();
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKT()
-{
-    constexpr index_t smem_size_kt = [&]() {
-        if constexpr(KLoadOnce && !KTLoadOnce)
-            return 0;
-        else
-            return sizeof(typename Problem::KDataType) *
-                   MakeKTLdsBlockDescriptor<Problem>().get_element_space_size();
-    }();
-    return smem_size_kt;
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBiasT()
+{
+    return GetTransposedAlignmentBias<Problem>();
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
-{
-    constexpr index_t smem_size_v = [&]() {
-        if constexpr(VLoadOnce)
-            return 0;
-        else
-            return sizeof(typename Problem::VDataType) *
-                   MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
-    }();
-    return smem_size_v;
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad()
+{
+    return GetAlignmentOGrad<Problem>();
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGrad()
-{
-    constexpr index_t smem_size_do = sizeof(typename Problem::OGradDataType) *
-                                     MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size();
-    return smem_size_do;
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGradT()
+{
+    return GetTransposedAlignmentOGrad<Problem>();
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGradT()
-{
-    constexpr index_t smem_size_dot = [&]() {
-        if constexpr(OGradLoadOnce && !OGradTLoadOnce)
-            return 0;
-        else
-            return sizeof(typename Problem::OGradDataType) *
-                   MakeOGradTLdsBlockDescriptor<Problem>().get_element_space_size();
-    }();
-    return smem_size_dot;
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad()
+{
+    // TODO: this is for 3d layout
+    using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
+    return 16 / sizeof(GemmDataType);
+}
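Taken together, the `GetSmemKPack*` helpers above reduce to one rule: pack as many elements as fit in a single 16-byte (dwordx4) access. A minimal standalone check of that arithmetic; `half_t` here is only a 2-byte stand-in for fp16/bf16, not the ck_tile type:

    #include <cstdint>

    template <typename T>
    constexpr int smem_kpack() { return 16 / sizeof(T); }

    struct half_t { std::uint16_t bits; }; // 2-byte stand-in

    static_assert(smem_kpack<half_t>() == 8, "fp16/bf16: 8 elements per 16-byte access");
    static_assert(smem_kpack<float>() == 4, "fp32: 4 elements per 16-byte access");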
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSGrad()
-{
-    constexpr index_t smem_size_ds = sizeof(typename Problem::GemmDataType) *
-                                     MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size();
-    return smem_size_ds;
-}
+template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
+CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
+{
+    constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
+    constexpr auto MNLdsLayer =
+        (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
+    constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
+        make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
+                   number<MNPerBlock / MNLdsLayer>{},
+                   number<KPack>{}),
+        make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
+        number<KPack>{},
+        number<1>{});
+    constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
+        x_lds_block_desc_0,
+        make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
+                                                 number<KPerBlock / KPack * MNLdsLayer>{})),
+                   make_pass_through_transform(number<KPack>{})),
+        make_tuple(sequence<1, 0>{}, sequence<2>{}),
+        make_tuple(sequence<1, 0>{}, sequence<2>{}));
+    constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
+        x_lds_block_desc_permuted,
+        make_tuple(make_unmerge_transform(
+                       make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
+                   make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
+                   make_pass_through_transform(number<KPack>{})),
+        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
+        make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
+    constexpr auto x_lds_block_desc = transform_tensor_descriptor(
+        x_lds_block_desc_xk0_mnldslayer_mn_xk1,
+        make_tuple(make_merge_transform_v3_division_mod(
+                       make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
+                   make_merge_transform_v3_division_mod(
+                       make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
+        make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}));
+    return x_lds_block_desc;
+}
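A minimal model of what the `make_xor_transform` above buys: assuming the usual reading of this transform (the K0-chunk index is xor-ed with the row index), rows that would otherwise hit the same LDS bank for a given chunk get their chunks permuted into distinct slots. Illustrative only, not the actual descriptor algebra:

    #include <cstdio>

    int main()
    {
        constexpr int K0 = 8; // e.g. KPerBlock = 64 at 8 elements per KPack
        for(int row = 0; row < 4; ++row)
            for(int k0 = 0; k0 < K0; ++k0)
            {
                int swizzled = k0 ^ row; // xor swizzle: each row permutes its chunk slots
                std::printf("row %d chunk %d -> slot %d\n", row, k0, swizzled);
            }
    }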
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias()
-{
-    constexpr index_t smem_size_bias = [&]() {
-        if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
-            return sizeof(typename Problem::BiasDataType) *
-                   MakeBiasTLdsBlockDescriptor<Problem>().get_element_space_size();
-        else
-            return 0;
-    }();
-    return smem_size_bias;
-}
+template <typename Problem, index_t MNPerBlock, index_t KPerBlock, index_t KPack, index_t KPackT>
+CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
+{
+    // kfold and mpair dimension is not always required.
+    // more dimension in merge_transform increase the difficulty of generating immarg offset
+    // for compiler.
+    constexpr auto MNPerXDL   = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
+    constexpr auto kBlockSize = Problem::kBlockSize;
+    constexpr auto MN0        = MNPerBlock / KPack;
+    constexpr auto MN1        = KPack;
+    constexpr auto KThreadWrite     = kBlockSize / MN0;
+    constexpr auto K0Number         = KPerBlock / KPackT;
+    constexpr auto K0PerThreadWrite = K0Number / KThreadWrite;
+    constexpr auto KThreadRead      = get_warp_size() / MNPerXDL; // assume 32x32x8 mfma
+    constexpr auto K0PerThreadRead  = K0Number / KThreadRead;
+    constexpr auto kfold = (KPackT * MN0 * 2 > 128) ? 1 : 128 / (KPackT * MN0 * 2);
+    constexpr auto KThreadReadPerm =
+        (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
+            ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
+            : KThreadRead;
+    // 1<=mnpair<=n0
+    constexpr auto mnpair = (KPackT * MNPerXDL * 2 > 128)
+                                ? 1
+                                : ((128 / (KPackT * MNPerXDL * 2)) > MN0
+                                       ? MN0
+                                       : 128 / (KPackT * MNPerXDL * 2));
+    constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor(
+        make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
+                   number<K0PerThreadWrite>{},
+                   number<KThreadReadPerm * MN1>{},
+                   number<kfold * MN0 / mnpair>{},
+                   number<mnpair>{},
+                   KPackT),
+        make_tuple(number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
+                   number<KPackT * kfold * MN0 * KThreadReadPerm * MN1>{},
+                   number<KPackT * kfold * MN0>{},
+                   number<KPackT * mnpair>{},
+                   number<KPackT>{},
+                   number<1>{}),
+        number<KPackT>{},
+        number<1>{});
+    constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor(
+        xt_lds_block_desc_raw,
+        make_tuple(
+            make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
+            make_pass_through_transform(number<K0PerThreadWrite>{}),
+            make_xor_transform(
+                make_tuple(number<KThreadReadPerm * MN1>{}, number<kfold * MN0 / mnpair>{})),
+            make_pass_through_transform(number<mnpair>{}),
+            make_pass_through_transform(KPackT)),
+        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
+    constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor(
+        xt_lds_block_desc_permuted,
+        make_tuple(
+            make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
+            make_pass_through_transform(number<K0PerThreadWrite>{}),
+            make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<MN1>{})),
+            make_unmerge_transform(make_tuple(number<kfold>{}, number<MN0 / mnpair>{})),
+            make_pass_through_transform(number<mnpair>{}),
+            make_pass_through_transform(KPackT)),
+        make_tuple(sequence<0>{},
+                   sequence<1>{},
+                   sequence<2>{},
+                   sequence<3>{},
+                   sequence<4>{},
+                   sequence<5>{}),
+        make_tuple(sequence<1>{},
+                   sequence<2>{},
+                   sequence<0, 3>{},
+                   sequence<4, 5>{},
+                   sequence<6>{},
+                   sequence<7>{}));
+    constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
+        xt_lds_block_desc_unmerged,
+        make_tuple(make_merge_transform_v3_division_mod(
+                       make_tuple(number<KThreadReadPerm>{},
+                                  number<KThreadWrite / kfold / KThreadReadPerm>{},
+                                  number<kfold>{},
+                                  number<K0PerThreadWrite>{},
+                                  number<KPackT>{})),
+                   make_merge_transform_v3_division_mod(
+                       make_tuple(number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
+        make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}));
+    return xt_lds_block_desc;
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
-{
-    constexpr index_t smem_size_q    = GetSmemSizeQ<Problem>();
-    constexpr index_t smem_size_qt   = GetSmemSizeQT<Problem>();
-    constexpr index_t smem_size_k    = GetSmemSizeK<Problem>();
-    constexpr index_t smem_size_kt   = GetSmemSizeKT<Problem>();
-    constexpr index_t smem_size_v    = GetSmemSizeV<Problem>();
-    constexpr index_t smem_size_do   = GetSmemSizeOGrad<Problem>();
-    constexpr index_t smem_size_dot  = GetSmemSizeOGradT<Problem>();
-    constexpr index_t smem_size_ds   = GetSmemSizeSGrad<Problem>();
-    constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
-    constexpr index_t smem_size_transpose = max(smem_size_ds, smem_size_bias);
-    index_t smem_size = 0;
-    if constexpr(QLoadOnce && OGradLoadOnce)
-        smem_size += smem_size_q + smem_size_qt + smem_size_do + smem_size_dot +
-                     smem_size_transpose; // 1~4 & 10
-    else if(QLoadOnce && !OGradLoadOnce && !OGradTLoadOnce)
-        smem_size += smem_size_q + smem_size_qt +
-                     max(smem_size_do, smem_size_dot,
-                         smem_size_transpose); // 5/7/11 TODO: Multiple buffers strategy
-    else if(!QLoadOnce && !QTLoadOnce && OGradLoadOnce)
-        smem_size += smem_size_do + smem_size_dot +
-                     max(smem_size_q, smem_size_qt,
-                         smem_size_transpose); // 6/8/12 TODO: Multiple buffers strategy
-    else if(!QLoadOnce && !QTLoadOnce && !OGradLoadOnce && !OGradTLoadOnce)
-        smem_size += max(smem_size_q, smem_size_qt, smem_size_do, smem_size_dot,
-                         smem_size_transpose); // 9/13 TODO: Multiple buffers strategy
-    // 14/15 needs to be adjusted
-    if constexpr(KLoadOnce)
-        smem_size += (smem_size_k + smem_size_kt); // 1~13
-    else
-        smem_size = max(smem_size_k, smem_size_kt,
-                        smem_size); // 14/15 TODO: Multiple buffers strategy
-    return max(smem_size, smem_size_v); // 15
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
+{
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+    constexpr index_t kKPack     = GetSmemKPackK<Problem>();
+    return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
+}
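The removed `GetSmemSize()` spelled out the aliasing scheme: tiles that stay resident (load-once) are summed into the budget, while tiles streamed one K-slice at a time can reuse the same scratch through `max()`. A toy recomputation for one hypothetical configuration; all byte counts below are invented for illustration:

    #include <algorithm>

    constexpr int q = 16384, qt = 16384, dO = 16384, dOt = 16384; // resident tiles
    constexpr int k = 16384, kt = 16384, v = 8192;
    constexpr int ds = 8192, bias = 0;
    constexpr int transpose = std::max(ds, bias);
    // QLoadOnce && OGradLoadOnce with KLoadOnce (the "1~4 & 10" cases above):
    constexpr int smem = q + qt + dO + dOt + transpose + (k + kt);
    static_assert(std::max(smem, v) == smem, "V staging fits inside the resident budget here");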
-template <typename Problem, typename BlockGemm>
-CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
-{
-    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-    using WG = remove_cvref_t<decltype(config.template at<0>())>;
-    constexpr index_t MWarp = config.template at<1>();
-    constexpr index_t NWarp = config.template at<2>();
-    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-    constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
-    constexpr index_t N0 = NWarp;
-    constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * 2;
-    constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
-    constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / 2;
-    constexpr index_t M1 = MWarp;
-    constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<N0, N1>,
-                                   tuple<sequence<M0, M1, M2, M3, M4>>,
-                                   tuple<sequence<1, 0>, sequence<1, 0>>,
-                                   tuple<sequence<1, 0>, sequence<3, 1>>,
-                                   sequence<1, 1, 1>,
-                                   sequence<0, 2, 4>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+    constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto k_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<MWarp>,
+                                   tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<0, 1>>,
+                                   tuple<sequence<0, 1>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
+    constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
+    return k_block_dstr;
+}
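All of the `*RegSliceBlockDescriptor` helpers in this hunk follow the same shape algebra: each K-step of the block tile is covered by `NWarp * WarpGemm::kN` columns, and whatever the warps do not cover becomes per-warp iteration counts that the embed encoding unrolls. Worked numbers for one assumed configuration (values are illustrative, not taken from the source):

    constexpr int kNPerBlock = 128, kKPerBlock = 32; // assumed block tile
    constexpr int NWarp = 4;                         // assumed warps along N
    constexpr int WG_kN = 32, WG_kK = 8;             // assumed 32x32x8 warp gemm
    constexpr int NIterPerWarp = kNPerBlock / (NWarp * WG_kN); // 1
    constexpr int KIterPerWarp = kKPerBlock / WG_kK;           // 4
    static_assert(NIterPerWarp * NWarp * WG_kN == kNPerBlock);
    static_assert(KIterPerWarp * WG_kK == kKPerBlock);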
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution()
-{
-    using VDataType = remove_cvref_t<typename Problem::VDataType>;
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
-    constexpr index_t K1 = 16 / sizeof(VDataType);
-    constexpr index_t K0 = kKPerBlock / K1;
-    constexpr index_t N2 = get_warp_size() / K0; // coalesce reading for each blocks
-    constexpr index_t N1 = kBlockSize / get_warp_size();
-    constexpr index_t N0 = kNPerBlock / (N2 * N1);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
-                                   tuple<sequence<1>, sequence<1, 2>>,
-                                   tuple<sequence<1>, sequence<2, 0>>,
-                                   sequence<1, 2>,
-                                   sequence<0, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
+    constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto k_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<MWarp>,
+                                   tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<0, 1>>,
+                                   tuple<sequence<0, 1>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
+    constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
+    return k_block_dstr;
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(QLoadOnce)
-            return Problem::BlockFmhaShape::kQKHeaddim;
-        else
-            return Problem::BlockFmhaShape::kK0;
-    }();
-    constexpr index_t K1 = GetAlignmentQ<Problem>();
-    constexpr index_t K0 = kKPerBlock / K1;
-    constexpr index_t M2 = get_warp_size() / K0; // coalesce reading for each blocks
-    constexpr index_t M1 = kBlockSize / get_warp_size();
-    constexpr index_t M0 = kMPerBlock / (M2 * M1);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
-                                   tuple<sequence<1>, sequence<1, 2>>,
-                                   tuple<sequence<1>, sequence<2, 0>>,
-                                   sequence<1, 2>,
-                                   sequence<0, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
+{
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
+    constexpr index_t kVPack     = GetSmemKPackV<Problem>();
+    return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
+}
+
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
+    constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto v_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<MWarp>,
+                                   tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<0, 1>>,
+                                   tuple<sequence<0, 1>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
+    constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
+    return v_block_dstr;
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(KLoadOnce)
-            return Problem::BlockFmhaShape::kQKHeaddim;
-        else
-            return Problem::BlockFmhaShape::kK0;
-    }();
-    constexpr index_t K1 = GetAlignmentK<Problem>();
-    constexpr index_t K0 = kKPerBlock / K1;
-    constexpr index_t N2 = get_warp_size() / K0; // coalesce reading for each blocks
-    constexpr index_t N1 = kBlockSize / get_warp_size();
-    constexpr index_t N0 = kNPerBlock / (N2 * N1);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
-                                   tuple<sequence<1>, sequence<1, 2>>,
-                                   tuple<sequence<1>, sequence<2, 0>>,
-                                   sequence<1, 2>,
-                                   sequence<0, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
+    constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto v_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<MWarp>,
+                                   tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<0, 1>>,
+                                   tuple<sequence<0, 1>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
+    constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
+    return v_block_dstr;
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(OGradLoadOnce)
-            return Problem::BlockFmhaShape::kVHeaddim;
-        else
-            return Problem::BlockFmhaShape::kK2;
-    }();
-    constexpr index_t K1 = GetAlignmentOGrad<Problem>();
-    constexpr index_t K0 = kKPerBlock / K1;
-    constexpr index_t M2 = get_warp_size() / K0; // coalesce reading for each blocks
-    constexpr index_t M1 = kBlockSize / get_warp_size();
-    constexpr index_t M0 = kMPerBlock / (M2 * M1);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
-                                   tuple<sequence<1>, sequence<1, 2>>,
-                                   tuple<sequence<1>, sequence<2, 0>>,
-                                   sequence<1, 2>,
-                                   sequence<0, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor()
+{
+    constexpr index_t kBlockSize = Problem::kBlockSize;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+    constexpr index_t K1 = GetAlignmentK<Problem>();
+    constexpr index_t K0 = kKPerBlock / K1;
+    constexpr index_t N2 = GetTransposedAlignmentK<Problem>();
+    constexpr index_t N1 = get_warp_size() / K0;
+    constexpr index_t N0 = kBlockSize / get_warp_size();
+    return make_static_tile_distribution(
+        tile_distribution_encoding<sequence<>,
+                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
+                                   tuple<sequence<1>, sequence<1, 2>>,
+                                   tuple<sequence<0>, sequence<1, 0>>,
+                                   sequence<2, 1>,
+                                   sequence<1, 2>>{});
+}
-template <typename DataType, index_t MPerBlock, index_t KPerBlock>
-CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution()
-{
-    constexpr index_t K1 = 16 / sizeof(DataType);
-    constexpr index_t K0 = KPerBlock / K1;
-    constexpr index_t M2 = 1;
-    constexpr index_t M1 = get_warp_size();
-    constexpr index_t M0 = MPerBlock / M1;
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
-                                   tuple<sequence<1>, sequence<1>>,
-                                   tuple<sequence<0>, sequence<1>>,
-                                   sequence<1, 2, 2>,
-                                   sequence<2, 0, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKLdsWriteBlockDescriptor()
+{
+    // Hold all data
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPack  = GetSmemKPackK<Problem>();
+    constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
+    return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
+}
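`MakePreXDramTileDistribution` (on the removed side of this hunk) encodes exactly one 16-byte load per lane per K-chunk: K1 elements fill the vector, K0 chunks tile the row, and one wave of M1 lanes covers M1 rows at a time. Spelled out standalone for assumed fp16 data and a 64-lane wave (sizes are illustrative):

    constexpr int warp_size = 64;            // assumed wavefront size
    constexpr int MPerBlock = 64, KPerBlock = 128, ElemBytes = 2;
    constexpr int K1 = 16 / ElemBytes;       // 8 elements per vector load
    constexpr int K0 = KPerBlock / K1;       // 16 chunks per row
    constexpr int M1 = warp_size, M2 = 1;
    constexpr int M0 = MPerBlock / M1;       // row passes per wave
    static_assert(K0 * K1 == KPerBlock && M0 * M1 * M2 == MPerBlock);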
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution()
-{
-    using ODataType = remove_cvref_t<typename Problem::ODataType>;
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kKPerBlock = Problem::kVHeaddim;
-    return MakePreXDramTileDistribution<ODataType, kBlockSize, kKPerBlock>();
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsReadBlockDescriptor()
+{
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
+    auto shuffled_k_lds_block_desc = MakeShuffledKLdsWriteBlockDescriptor<Problem>();
+    return transform_tensor_descriptor(
+        shuffled_k_lds_block_desc,
+        make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
+                   make_pass_through_transform(number<kKPerBlock>{})),
+        make_tuple(sequence<1>{}, sequence<0>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}));
+}
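The read descriptor above is only a re-labelling: the two pass-through transforms expose the (K, N) write-side view as an (N, K) read-side view over the same LDS bytes. The same idea in miniature, over a row-major toy layout:

    constexpr int N = 4;
    constexpr int write_offset(int k, int n) { return k * N + n; }            // (K, N) view
    constexpr int read_offset(int n, int k) { return write_offset(k, n); }    // (N, K) view
    static_assert(read_offset(2, 1) == write_offset(1, 2), "same byte, two index orders");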
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution()
-{
-    using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kKPerBlock = Problem::kVHeaddim;
-    return MakePreXDramTileDistribution<OGradDataType, kBlockSize, kKPerBlock>();
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto kt_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<MWarp>,
+                                   tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<0, 1>>,
+                                   tuple<sequence<0, 1>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
+    constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
+    return kt_block_dstr;
+}
-template <typename Problem>
-CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(QTLoadOnce)
-            return Problem::BlockFmhaShape::kM0;
-        else
-            return Problem::BlockFmhaShape::kK3;
-    }();
-    constexpr index_t N1 = GetTransposedAlignmentQ<Problem>();
-    constexpr index_t N0 = kNPerBlock / N1; // P
-    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-    static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-    constexpr index_t K3     = total_pixels / N1;
-    constexpr index_t kKPack = GetSmemKPackQ<Problem>();
-    static_assert(kKPack % K3 == 0);
-    constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
-    constexpr index_t K1 = get_warp_size() / (K2 * N0);
-    constexpr index_t K0 = kBlockSize / get_warp_size();
-    static_assert(kKPerBlock == K0 * K1 * K2 * K3);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
-                                   tuple<sequence<2>, sequence<2, 1, 2>>,
-                                   tuple<sequence<0>, sequence<1, 0, 2>>,
-                                   sequence<2, 1>,
-                                   sequence<3, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
+{
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+    constexpr index_t kKPack     = GetSmemKPackQ<Problem>();
+    return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
+}
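The removed `MakeQTDramTileDistribution` factors the K side into K0*K1*K2*K3 so that each thread ends up owning `total_pixels` elements laid out for the later in-register transpose. One concrete instantiation with assumed sizes, checked at compile time (the values are illustrative, not taken from the source):

    constexpr int kBlockSize = 256, warp_size = 64;
    constexpr int kNPerBlock = 128, kKPerBlock = 32; // headdim x kK3, assumed
    constexpr int N1 = 8, N0 = kNPerBlock / N1;      // transposed-alignment split
    constexpr int total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // 16 per thread
    constexpr int K3 = total_pixels / N1;            // 2
    constexpr int kKPack = 8, K2 = kKPack / K3;      // 4
    constexpr int K1 = warp_size / (K2 * N0);        // 1
    constexpr int K0 = kBlockSize / warp_size;       // 4
    static_assert(kKPerBlock == K0 * K1 * K2 * K3);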
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(QTLoadOnce)
-            return Problem::BlockFmhaShape::kM0;
-        else
-            return Problem::BlockFmhaShape::kK3;
-    }();
-    constexpr index_t N1 = GetTransposedAlignmentQ<Problem>();
-    constexpr index_t N0 = kNPerBlock / N1;
-    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-    static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-    constexpr index_t K3     = total_pixels / N1;
-    constexpr index_t kKPack = GetSmemKPackQ<Problem>();
-    static_assert(kKPack % K3 == 0);
-    constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
-    constexpr index_t K1 = get_warp_size() / (K2 * N0);
-    constexpr index_t K0 = kBlockSize / get_warp_size();
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
-                                   tuple<sequence<2>, sequence<2, 1, 2>>,
-                                   tuple<sequence<0>, sequence<1, 0, 2>>,
-                                   sequence<1, 2>,
-                                   sequence<1, 3>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+    constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto q_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<NWarp>,
+                                   tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<1, 0>>,
+                                   tuple<sequence<1, 0>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
+    constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
+    return q_block_dstr;
+}
-template <typename Problem>
-CK_TILE_DEVICE static constexpr auto MakeKTDramTileDistribution()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(KTLoadOnce)
-            return Problem::BlockFmhaShape::kN0;
-        else
-            return Problem::BlockFmhaShape::kK4;
-    }();
-    constexpr index_t N1 = GetTransposedAlignmentK<Problem>();
-    constexpr index_t N0 = kNPerBlock / N1; // P
-    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-    static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-    constexpr index_t K3     = total_pixels / N1;
-    constexpr index_t kKPack = GetSmemKPackK<Problem>();
-    static_assert(kKPack % K3 == 0);
-    constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
-    constexpr index_t K1 = get_warp_size() / (K2 * N0);
-    constexpr index_t K0 = kBlockSize / get_warp_size();
-    static_assert(kKPerBlock == K0 * K1 * K2 * K3);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
-                                   tuple<sequence<2>, sequence<2, 1, 2>>,
-                                   tuple<sequence<0>, sequence<1, 0, 2>>,
-                                   sequence<2, 1>,
-                                   sequence<3, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor()
+{
+    constexpr index_t kBlockSize = Problem::kBlockSize;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+    constexpr index_t K1 = GetAlignmentQ<Problem>();
+    constexpr index_t K0 = kKPerBlock / K1;
+    constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
+    constexpr index_t N1 = get_warp_size() / K0;
+    constexpr index_t N0 = kBlockSize / get_warp_size();
+    return make_static_tile_distribution(
+        tile_distribution_encoding<sequence<>,
+                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
+                                   tuple<sequence<1>, sequence<1, 2>>,
+                                   tuple<sequence<0>, sequence<1, 0>>,
+                                   sequence<2, 1>,
+                                   sequence<1, 2>>{});
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKTRegBlockDescriptor()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(KTLoadOnce)
-            return Problem::BlockFmhaShape::kN0;
-        else
-            return Problem::BlockFmhaShape::kK4;
-    }();
-    constexpr index_t N1 = GetTransposedAlignmentK<Problem>();
-    constexpr index_t N0 = kNPerBlock / N1;
-    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-    static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-    constexpr index_t K3     = total_pixels / N1;
-    constexpr index_t kKPack = GetSmemKPackK<Problem>();
-    static_assert(kKPack % K3 == 0);
-    constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
-    constexpr index_t K1 = get_warp_size() / (K2 * N0);
-    constexpr index_t K0 = kBlockSize / get_warp_size();
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
-                                   tuple<sequence<2>, sequence<2, 1, 2>>,
-                                   tuple<sequence<0>, sequence<1, 0, 2>>,
-                                   sequence<1, 2>,
-                                   sequence<1, 3>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
+{
+    // Hold full block data
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t kKPack  = GetSmemKPackQ<Problem>();
+    constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
+    return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
+}
+
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsReadBlockDescriptor()
+{
+    // Hold full block data
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
+    auto shuffled_q_lds_block_desc = MakeShuffledQLdsWriteBlockDescriptor<Problem>();
+    return transform_tensor_descriptor(
+        shuffled_q_lds_block_desc,
+        make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
+                   make_pass_through_transform(number<kKPerBlock>{})),
+        make_tuple(sequence<1>{}, sequence<0>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}));
+}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
+    constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto qt_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<MWarp>,
+                                   tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<0, 1>>,
+                                   tuple<sequence<0, 1>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
+    constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode);
+    return qt_block_dstr;
+}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
+    constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto dst_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<NWarp>,
+                                   tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<1, 0>>,
+                                   tuple<sequence<1, 0>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
+    constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
+    return dst_block_dstr;
+}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsWriteBlockDescriptor()
+{
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+    using LSEDType = remove_cvref_t<typename Problem::DDataType>;
+    constexpr index_t kMPack = 16 / sizeof(LSEDType);
+    constexpr auto lsed_lds_block_desc = make_naive_tensor_descriptor(
+        make_tuple(number<kMPerBlock>{}),
+        make_tuple(number<1>{}),
+        number<kMPack>{},
+        number<1>{});
+    return lsed_lds_block_desc;
+}
+template <typename Problem, typename BlockGemm>
+CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
+{
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WG = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = config.template at<1>();
+    constexpr index_t NWarp = config.template at<2>();
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
+    constexpr index_t N0 = NWarp;
+    // M4 *2 and M2 /2 when swizzle mode enabled
+    constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
+    // constexpr index_t SwizzleConfig = 1;
+    constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
+    constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
+    constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
+    constexpr index_t M1 = MWarp;
+    constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
+    return make_static_tile_distribution(
+        tile_distribution_encoding<sequence<N0, N1>,
+                                   tuple<sequence<M0, M1, M2, M3, M4>>,
+                                   tuple<sequence<1, 0>, sequence<1, 0>>,
+                                   tuple<sequence<1, 0>, sequence<3, 1>>,
+                                   sequence<1, 1, 1>,
+                                   sequence<0, 2, 4>>{});
+}
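`SwizzleConfig` above trades the shape of each lane's M-slice: for warp tiles wider than 16 in M it doubles the contiguous run (M4) and halves the run count (M2), leaving the per-lane element count unchanged. Isolated below with assumed Impl constants (kCM1PerLane = 2 and kCM0PerLane = 4 are illustrative, not read from the source):

    constexpr int kCM1PerLane = 2, kCM0PerLane = 4;   // assumed warp-gemm Impl values
    constexpr int swizzle(int wg_m) { return wg_m == 16 ? 1 : 2; }
    static_assert(kCM1PerLane * swizzle(32) * (kCM0PerLane / swizzle(32)) ==
                      kCM1PerLane * swizzle(16) * (kCM0PerLane / swizzle(16)),
                  "per-lane element count is invariant under the swizzle");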
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor()
+{
+    // Hold full block data
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
+    constexpr index_t kKPack     = GetSmemKPackOGrad<Problem>();
+    return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
+}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
+    constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto do_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<NWarp>,
+                                   tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<1, 0>>,
+                                   tuple<sequence<1, 0>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
+    constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
+    return do_block_dstr;
+}
-template <typename Problem>
-CK_TILE_DEVICE static constexpr auto MakeOGradTDramTileDistribution()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(OGradTLoadOnce)
-            return Problem::BlockFmhaShape::kM0;
-        else
-            return Problem::BlockFmhaShape::kK1;
-    }();
-    constexpr index_t N1 = GetTransposedAlignmentOGrad<Problem>();
-    constexpr index_t N0 = kNPerBlock / N1; // P
-    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-    static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-    constexpr index_t K3     = total_pixels / N1;
-    constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
-    static_assert(kKPack % K3 == 0);
-    constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
-    constexpr index_t K1 = get_warp_size() / (K2 * N0);
-    constexpr index_t K0 = kBlockSize / get_warp_size();
-    static_assert(kKPerBlock == K0 * K1 * K2 * K3);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
-                                   tuple<sequence<2>, sequence<2, 1, 2>>,
-                                   tuple<sequence<0>, sequence<1, 0, 2>>,
-                                   sequence<2, 1>,
-                                   sequence<3, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
+{
+    constexpr index_t kBlockSize = Problem::kBlockSize;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
+    constexpr index_t K1 = GetAlignmentOGrad<Problem>();
+    constexpr index_t K0 = kKPerBlock / K1;
+    constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
+    constexpr index_t N1 = get_warp_size() / K0;
+    constexpr index_t N0 = kBlockSize / get_warp_size();
+    return make_static_tile_distribution(
+        tile_distribution_encoding<sequence<>,
+                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
+                                   tuple<sequence<1>, sequence<1, 2>>,
+                                   tuple<sequence<0>, sequence<1, 0>>,
+                                   sequence<2, 1>,
+                                   sequence<1, 2>>{});
+}
-template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradTRegBlockDescriptor()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
-    constexpr index_t kKPerBlock = [&]() {
-        if constexpr(OGradTLoadOnce)
-            return Problem::BlockFmhaShape::kM0;
-        else
-            return Problem::BlockFmhaShape::kK1;
-    }();
-    constexpr index_t N1 = GetTransposedAlignmentOGrad<Problem>();
-    constexpr index_t N0 = kNPerBlock / N1;
-    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-    static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-    constexpr index_t K3     = total_pixels / N1;
-    constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
-    static_assert(kKPack % K3 == 0);
-    constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
-    constexpr index_t K1 = get_warp_size() / (K2 * N0);
-    constexpr index_t K0 = kBlockSize / get_warp_size();
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
-                                   tuple<sequence<2>, sequence<2, 1, 2>>,
-                                   tuple<sequence<0>, sequence<1, 0, 2>>,
-                                   sequence<1, 2>,
-                                   sequence<1, 3>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
+{
+    // Hold all data
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t kKPack  = GetSmemKPackOGrad<Problem>();
+    constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
+    return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
+}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsReadBlockDescriptor()
+{
+    // Hold all data
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
+    auto shuffled_do_lds_block_desc = MakeShuffledOGradLdsWriteBlockDescriptor<Problem>();
+    return transform_tensor_descriptor(
+        shuffled_do_lds_block_desc,
+        make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
+                   make_pass_through_transform(number<kKPerBlock>{})),
+        make_tuple(sequence<1>{}, sequence<0>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}));
+}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
+    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
+    // constexpr index_t kNPerBlock = 32;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+    constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto dot_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<MWarp>,
+                                   tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<0, 1>>,
+                                   tuple<sequence<0, 1>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
+    constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode);
+    return dot_block_dstr;
+}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+    constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto pt_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<NWarp>,
+                                   tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<1, 0>>,
+                                   tuple<sequence<1, 0>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
+    constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
+    return pt_block_dstr;
+}
-template <typename Problem>
-CK_TILE_DEVICE static constexpr auto MakeBiasTileDistribution()
-{
-    constexpr index_t kBlockSize = Problem::kBlockSize;
-    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-    constexpr index_t N1 = GetTransposedAlignmentBias<Problem>();
-    constexpr index_t N0 = kNPerBlock / N1; // P
-    constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
-    static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-    constexpr index_t M3     = total_pixels / N1;
-    constexpr index_t kKPack = GetSmemKPackBias<Problem>();
-    static_assert(kKPack % M3 == 0);
-    constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
-    constexpr index_t M1 = get_warp_size() / (M2 * N0);
-    constexpr index_t M0 = kBlockSize / get_warp_size();
-    static_assert(kMPerBlock == M0 * M1 * M2 * M3);
-    return make_static_tile_distribution(
-        tile_distribution_encoding<sequence<>,
-                                   tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
-                                   tuple<sequence<1>, sequence<1, 2, 1>>,
-                                   tuple<sequence<0>, sequence<1, 0, 2>>,
-                                   sequence<1, 2>,
-                                   sequence<3, 1>>{});
-}
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor()
+{
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
+    constexpr index_t kKPack     = GetSmemKPackSGrad<Problem>();
+    return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
+}
+
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor()
+{
+    using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
+    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
+    constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
+    constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
+    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4;
+    constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
+    constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
+    constexpr auto ds_block_outer_dstr_encoding =
+        tile_distribution_encoding<sequence<NWarp>,
+                                   tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                   tuple<sequence<1, 0>>,
+                                   tuple<sequence<1, 0>>,
+                                   sequence<1, 2>,
+                                   sequence<0, 0>>{};
+    constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+        ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
+    constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode);
+    return ds_block_dstr;
+}
    template <typename Problem, typename PTOutTensor, typename PInTensor>
    CK_TILE_DEVICE static constexpr void PTFromGemm0CToGemm1A(PTOutTensor& pt_out,
                                                              const PInTensor& p_in)
    {
        if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16)
        {
            using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;

            constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
            using WarpGemm        = remove_cvref_t<decltype(config.template at<0>())>;

            constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});

            constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;

            constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
            constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;

            using AWarpDstr = typename WarpGemm::AWarpDstr;
            using CWarpDstr = typename WarpGemm::CWarpDstr;

            auto pt_warp_tensor =
                make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});

            constexpr auto a_warp_y_lengths =
                to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto c_warp_y_lengths =
                to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
            constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
                        merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    pt_out.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
                        pt_warp_tensor.get_thread_buffer());
                });
            });
        }
        else
        {
            pt_out.get_thread_buffer() = p_in.get_thread_buffer();
        }
    }
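PTFromGemm0CToGemm1A (and SGradTFromGemm2CToGemm3A below, which is structurally identical) repacks a C-tile thread buffer whose per-warp slices are ordered (kIter, mIter) into an A-tile buffer ordered (mIter, kIter); the swap is only needed when the 16-wide warp tile makes the two layouts differ. A minimal host-side sketch of that slice reorder, with assumed slice shapes:

// Sketch only: plain arrays stand in for the distributed tensors, and the
// slice extents are assumptions.
#include <array>
#include <cstdio>

int main()
{
    constexpr int KIter = 2, MIter = 2, YLen = 4; // assumed per-warp slice shape

    std::array<float, KIter * MIter * YLen> c_buf{}; // C-tile order: [k][m][y]
    for(int i = 0; i < static_cast<int>(c_buf.size()); ++i)
        c_buf[i] = static_cast<float>(i);

    std::array<float, MIter * KIter * YLen> a_buf{}; // A-tile order: [m][k][y]
    for(int k = 0; k < KIter; ++k)
        for(int m = 0; m < MIter; ++m)
            for(int y = 0; y < YLen; ++y)
                a_buf[(m * KIter + k) * YLen + y] = c_buf[(k * MIter + m) * YLen + y];

    std::printf("a_buf[4]=%g (copied from c_buf[8])\n", a_buf[4]);
    return 0;
}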
    template <typename Problem, typename SGradTOutTensor, typename SGradInTensor>
    CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out,
                                                                  const SGradInTensor& ds_in)
    {
        if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16)
        {
            using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;

            constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
            using WarpGemm        = remove_cvref_t<decltype(config.template at<0>())>;

            constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});

            constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;

            constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
            constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;

            using AWarpDstr = typename WarpGemm::AWarpDstr;
            using CWarpDstr = typename WarpGemm::CWarpDstr;

            auto dst_warp_tensor =
                make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});

            constexpr auto a_warp_y_lengths =
                to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto c_warp_y_lengths =
                to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
            constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
                        merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    dst_out.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
                        dst_warp_tensor.get_thread_buffer());
                });
            });
        }
        else
        {
            dst_out.get_thread_buffer() = ds_in.get_thread_buffer();
        }
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;

-       constexpr index_t N1 = GetTransposedAlignmentBias<Problem>();
+       constexpr index_t N1 = GetAlignmentBias<Problem>();
        constexpr index_t N0 = kNPerBlock / N1;

-       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
-       static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-       constexpr index_t M3     = total_pixels / N1;
-       constexpr index_t kKPack = GetSmemKPackBias<Problem>();
-       static_assert(kKPack % M3 == 0);
-       constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
-       constexpr index_t M1 = get_warp_size() / (M2 * N0);
+       constexpr index_t M2 = GetTransposedAlignmentBias<Problem>();
+       constexpr index_t M1 = get_warp_size() / N0;
        constexpr index_t M0 = kBlockSize / get_warp_size();

        return make_static_tile_distribution(
-           tile_distribution_encoding<sequence<>,
-                                      tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
-                                      tuple<sequence<1>, sequence<1, 2, 1>>,
-                                      tuple<sequence<0>, sequence<1, 0, 2>>,
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<2, 1>,
-                                      sequence<1, 3>>{});
+                                      sequence<1, 2>>{});
    }
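The new shuffled-bias encoding drops the old M3 split: N is divided by the plain load alignment and M by the transposed alignment. A worked example of the thread split, with assumed alignment values:

// Illustrative only: the alignments below are assumptions, not the values
// GetAlignmentBias / GetTransposedAlignmentBias would actually return.
#include <cstdio>

int main()
{
    constexpr int kBlockSize = 256, warp_size = 64;
    constexpr int kNPerBlock = 128;

    constexpr int N1 = 8;                      // assumed GetAlignmentBias<Problem>()
    constexpr int N0 = kNPerBlock / N1;        // 16 contiguous groups along N
    constexpr int M2 = 4;                      // assumed GetTransposedAlignmentBias<Problem>()
    constexpr int M1 = warp_size / N0;         // 4 rows covered per lane pass
    constexpr int M0 = kBlockSize / warp_size; // 4 warps

    std::printf("N0=%d M0=%d M1=%d M2=%d\n", N0, M0, M1, M2);
    return 0;
}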
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor()
    {
        // Hold full block data
        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kKPack     = GetSmemKPackBias<Problem>();
        constexpr index_t kKPackT    = GetSmemKPackBiasT<Problem>();

        return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kMPerBlock, kKPack, kKPackT>();
    }
    template <typename BlockGemm>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution()
+   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasSTileDistribution()
    {
        using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
        return c_block_tensor_type::get_tile_distribution();
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQ()
    {
        constexpr index_t smem_size_q =
            sizeof(typename Problem::QDataType) *
            MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_q;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQT()
    {
        constexpr index_t smem_size_qt =
            sizeof(typename Problem::QDataType) *
            MakeShuffledQLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_qt;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeK()
    {
        constexpr index_t smem_size_k =
            sizeof(typename Problem::KDataType) *
            MakeKLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_k;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeKT()
    {
        constexpr index_t smem_size_kt =
            sizeof(typename Problem::KDataType) *
            MakeKTLdsReadBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_kt;
    }

-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
-   {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::QDataType,
-                                    typename Problem::KDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                  Problem::BlockFmhaShape::kN0,
-                                                  Problem::BlockFmhaShape::kK0>>;
-
-       constexpr auto warp_gemm = []() {
-           if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
-                        std::is_same_v<typename Problem::KDataType, half_t> &&
-                        std::is_same_v<typename Problem::AccDataType, float>)
-           {
-               return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
-           }
-           else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
-                             std::is_same_v<typename Problem::KDataType, bf16_t> &&
-                             std::is_same_v<typename Problem::AccDataType, float>)
-           {
-               return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
-           }
-       }();
-
-       using BlockGemmPolicy =
-           BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::QDataType,
-                                                 typename Problem::KDataType,
-                                                 typename Problem::AccDataType,
-                                                 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
-                                                 decltype(warp_gemm)>;
-
-       return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
-   }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
    {
        constexpr index_t smem_size_lse =
            sizeof(typename Problem::LSEDataType) *
            MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_lse;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
    {
        constexpr index_t smem_size_d =
            sizeof(typename Problem::DDataType) *
            MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_d;
    }

-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
-   {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::GemmDataType,
-                                    typename Problem::OGradDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<Problem::BlockFmhaShape::kN0,
-                                                  Problem::BlockFmhaShape::kVHeaddim,
-                                                  Problem::BlockFmhaShape::kK1>>;
-
-       using WarpGemm =
-           WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
-                                  typename Problem::OGradDataType,
-                                  typename Problem::AccDataType,
-                                  Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
-                                  Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
-                                  Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
-                                  true>;
-
-       using BlockGemmPolicy =
-           BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::GemmDataType,
-                                                typename Problem::OGradDataType,
-                                                typename Problem::AccDataType,
-                                                typename Problem::BlockFmhaShape::Gemm1BlockWarps,
-                                                WarpGemm>;
-
-       return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
-   }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeV()
    {
        constexpr index_t smem_size_v =
            sizeof(typename Problem::VDataType) *
            MakeVLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_v;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGrad()
    {
        constexpr index_t smem_size_do =
            sizeof(typename Problem::OGradDataType) *
            MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_do;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGradT()
    {
        constexpr index_t smem_size_dot =
            sizeof(typename Problem::OGradDataType) *
            MakeShuffledOGradLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_dot;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeSGrad()
    {
        constexpr index_t smem_size_ds =
            sizeof(typename Problem::GemmDataType) *
            MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_ds;
    }

-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
-   {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::OGradDataType,
-                                    typename Problem::VDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                  Problem::BlockFmhaShape::kN0,
-                                                  Problem::BlockFmhaShape::kK2>>;
-
-       constexpr auto warp_gemm = []() {
-           if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
-                        std::is_same_v<typename Problem::VDataType, half_t> &&
-                        std::is_same_v<typename Problem::AccDataType, float>)
-           {
-               return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
-           }
-           else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
-                             std::is_same_v<typename Problem::VDataType, bf16_t> &&
-                             std::is_same_v<typename Problem::AccDataType, float>)
-           {
-               return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
-           }
-       }();
-
-       using BlockGemmPolicy =
-           BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
-                                                typename Problem::VDataType,
-                                                typename Problem::AccDataType,
-                                                typename Problem::BlockFmhaShape::Gemm2BlockWarps,
-                                                decltype(warp_gemm)>;
-
-       return BlockGemmASmemBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
-   }
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
// {
// using BlockGemmProblem =
// BlockGemmPipelineProblem<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// Problem::kBlockSize,
// TileGemmShape<Problem::BlockFmhaShape::kM0,
// Problem::BlockFmhaShape::kN0,
// Problem::BlockFmhaShape::kK2>>;
// constexpr auto warp_gemm = []() {
// if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
// std::is_same_v<typename Problem::VDataType, half_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
// }
// else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
// std::is_same_v<typename Problem::VDataType, bf16_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
// }
// }();
// using BlockGemmPolicy =
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
    // typename Problem::BlockFmhaShape::Gemm2BlockWarps,
// decltype(warp_gemm)>;
// return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
// }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeBias()
    {
        constexpr index_t smem_size_bias = [&]() {
            if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                return sizeof(typename Problem::BiasDataType) *
                       MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
            else
                return 0;
        }();
        return smem_size_bias;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        constexpr index_t smem_size_q    = GetSmemSizeQ<Problem>();
        constexpr index_t smem_size_qt   = GetSmemSizeQT<Problem>();
        constexpr index_t smem_size_lse  = GetSmemSizeLSE<Problem>();
        constexpr index_t smem_size_k    = GetSmemSizeK<Problem>();
        constexpr index_t smem_size_kt   = GetSmemSizeKT<Problem>();
        constexpr index_t smem_size_v    = GetSmemSizeV<Problem>();
        constexpr index_t smem_size_do   = GetSmemSizeOGrad<Problem>();
        constexpr index_t smem_size_dot  = GetSmemSizeOGradT<Problem>();
        constexpr index_t smem_size_d    = GetSmemSizeD<Problem>();
        constexpr index_t smem_size_ds   = GetSmemSizeSGrad<Problem>();
        constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();

        constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
        constexpr index_t smem_size_stage0_1 = smem_size_v;
        constexpr index_t smem_size_stage1   = smem_size_qt + smem_size_q + smem_size_dot +
                                             smem_size_do + smem_size_lse + smem_size_d +
                                             max(smem_size_bias, smem_size_ds);

        return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
    }

-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
-   {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::GemmDataType,
-                                    typename Problem::QDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<Problem::BlockFmhaShape::kN0,
-                                                  Problem::BlockFmhaShape::kQKHeaddim,
-                                                  Problem::BlockFmhaShape::kK3>>;
-
-       using WarpGemm =
-           WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
-                                  typename Problem::QDataType,
-                                  typename Problem::AccDataType,
-                                  Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
-                                  Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
-                                  Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
-                                  true>;
-
-       using BlockGemmPolicy =
-           BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::GemmDataType,
-                                                typename Problem::QDataType,
-                                                typename Problem::AccDataType,
-                                                typename Problem::BlockFmhaShape::Gemm3BlockWarps,
-                                                WarpGemm>;
-
-       return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
-   }
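GetSmemSize exploits the stage structure rather than summing every buffer: the K/K^T staging buffers (stage 0.0) and the V buffer (stage 0.1) are dead by the time the stage-1 working set (Q, Q^T, dO, dO^T, LSE, D, and bias or dS, which share space via the inner max) is resident, so LDS is sized by the maximum over stages. A back-of-envelope version with placeholder byte counts:

// Illustrative only: the byte counts are made-up placeholders, not values
// derived from this commit's descriptors.
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int smem_k = 16384, smem_kt = 16384;          // stage 0.0 (assumed)
    constexpr int smem_v = 32768;                           // stage 0.1 (assumed)
    constexpr int smem_q = 8192, smem_qt = 8192,            // stage 1 (assumed)
                  smem_do = 8192, smem_dot = 8192,
                  smem_lse = 512, smem_d = 512,
                  smem_bias = 0, smem_ds = 4096;

    constexpr int stage0_0 = smem_k + smem_kt;
    constexpr int stage0_1 = smem_v;
    constexpr int stage1   = smem_q + smem_qt + smem_do + smem_dot + smem_lse +
                             smem_d + std::max(smem_bias, smem_ds);

    std::printf("lds bytes = %d\n", std::max({stage0_0, stage0_1, stage1}));
    return 0;
}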
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
-   {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::GemmDataType,
-                                    typename Problem::KDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                  Problem::BlockFmhaShape::kQKHeaddim,
-                                                  Problem::BlockFmhaShape::kK4>>;
-
-       using WarpGemm =
-           WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
-                                  typename Problem::KDataType,
-                                  typename Problem::AccDataType,
-                                  Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
-                                  Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
-                                  Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
-                                  true>;
-
-       using BlockGemmPolicy =
-           BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::GemmDataType,
-                                                 typename Problem::KDataType,
-                                                 typename Problem::AccDataType,
-                                                 typename Problem::BlockFmhaShape::Gemm4BlockWarps,
-                                                 WarpGemm>;
-
-       return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
-   }

    template <typename Problem_>
    struct HotLoopScheduler
    {
        using Problem = Problem_;

        template <index_t GemmStage>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler()
        {
        }

        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
        {
            // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
            // Comp: Q x K
            constexpr index_t VMEM_READ_INST =
                Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
            constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
            constexpr index_t MFMA_INST     = Gemm0MFMA;

            // Evenly distributed to relieve SQ->TA FIFO pressure
            constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
            constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
            // To hide instruction issue latency
            constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;

            static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
                    ignore = j;
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                 // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
                });
            });
            static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                 // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
            });
        }
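Stage 0 spreads its global loads across the Q x K MFMAs instead of issuing them back-to-back, relieving the SQ->TA FIFO pressure the comment mentions; leftover MFMAs that do not divide evenly are issued afterwards. The split arithmetic, with assumed instruction counts:

// Illustrative only: the instruction counts are assumptions, not the derived
// constants (Q_VMEM_READ, Gemm0MFMA, ...) used above.
#include <cstdio>

int main()
{
    constexpr int VMEM_READ_INST = 6;  // assumed Q + dO + LSE + D load count
    constexpr int MFMA_INST      = 16; // assumed Gemm0 MFMA count

    constexpr int MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;                  // 2
    constexpr int MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST; // 4

    std::printf("mfma per vmem read=%d remainder=%d\n", MFMA_PER_VMEM_READ, MFMA_Remainder);
    return 0;
}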
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
        {
            // Mem: Q^T LDS load
            // Comp: OGrad x V
            constexpr index_t LDS_READ_INST = QT_LDS_READ;
            constexpr index_t MFMA_INST     = Gemm1MFMA;

            // To hide instruction issue latency
            constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;

            static_for<0, MFMA_INST, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                 // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
            });
        }
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
        {
            // Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
            // Comp: PT x OGrad
            constexpr index_t LDS_WRITE_INST = Q_LDS_WRITE + QT_LDS_WRITE + OGrad_LDS_WRITE +
                                               OGradT_LDS_WRITE + LSE_LDS_WRITE + D_LDS_WRITE;
            constexpr index_t MFMA_INST = Gemm2MFMA;

            // To hide instruction issue latency
            constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;

            static_for<0, MFMA_INST, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
            });
        }
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
        {
            // Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
            // Comp: SGradT x QT
            constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
            constexpr index_t LDS_READ_INST  = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
            constexpr index_t MFMA_INST      = Gemm3MFMA;

            // To hide instruction issue latency
            constexpr index_t LDS_WRITE_PER_MFMA =
                LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
            constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
            constexpr index_t LDS_READ_PER_MFMA =
                (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
                    ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
                          ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
                          : 1
                    : 0;

            static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
            });
            static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                 // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
            });
        }
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
        {
            // Mem: SGrad, OGrad, D LDS load.
            // Comp: SGrad x KT
            constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
            constexpr index_t MFMA_INST     = Gemm4MFMA;

            // To hide instruction issue latency
            constexpr index_t LDS_READ_PER_MFMA =
                LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;

            static_for<0, MFMA_INST, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                 // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
            });
        }
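All five stages rely on the same idiom: __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) pins the next `size` instructions matching `mask` at that point in the compiler's schedule, with the masks used here being 0x008 (MFMA), 0x020 (VMEM read), 0x100 (DS read), and 0x200 (DS write), as the inline comments above confirm. A stripped-down sketch of the MFMA/DS-read pairing, compiled only for AMD GPU device code; the counts are template parameters rather than the derived constants used above:

#if defined(__HIPCC__)
// Sketch only: interleave one MFMA with a fixed number of LDS reads in the
// instruction schedule, mirroring the per-stage loops above.
template <int MfmaInst, int DsReadPerMfma>
__device__ void interleave_mfma_ds_read()
{
#pragma unroll
    for(int i = 0; i < MfmaInst; ++i)
    {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);             // one MFMA
        __builtin_amdgcn_sched_group_barrier(0x100, DsReadPerMfma, 0); // its DS reads
    }
}
#endif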
    private:
        static constexpr index_t kBlockSize = Problem::kBlockSize;
        static constexpr index_t kM0        = Problem::BlockFmhaShape::kM0;
        static constexpr index_t kN0        = Problem::BlockFmhaShape::kN0;
        static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
        static constexpr index_t kVHeaddim  = Problem::BlockFmhaShape::kVHeaddim;
        static constexpr index_t kK4        = Problem::BlockFmhaShape::kK4;

        static constexpr index_t WarpGemmM =
            Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
        static constexpr index_t WarpGemmN =
            Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
        static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;

        static constexpr index_t Gemm4MWarp =
            Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
        static constexpr index_t Gemm4NWarp =
            Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});

        // Compute
        static constexpr index_t Gemm0MFMA =
            kM0 * kN0 * kQKHeaddim /
            (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
        static constexpr index_t Gemm1MFMA =
            kM0 * kN0 * kVHeaddim /
            (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
        static constexpr index_t Gemm2MFMA =
            kN0 * kVHeaddim * kM0 /
            (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
        static constexpr index_t Gemm3MFMA =
            kN0 * kQKHeaddim * kM0 /
            (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
        static constexpr index_t Gemm4MFMA =
            kM0 * kQKHeaddim * kN0 /
            (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);

        // VMEM
        static constexpr index_t Q_VMEM_READ =
            kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
        static constexpr index_t OGrad_VMEM_READ =
            kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
        static constexpr index_t LSE_VMEM_READ = 1;
        static constexpr index_t D_VMEM_READ   = 1;

        // LDS Read
        static constexpr index_t OGradT_LDS_READ =
            kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
        static constexpr index_t QT_LDS_READ =
            kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
        static constexpr index_t SGradT_LDS_READ_P1 =
            kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
        static constexpr index_t Q_LDS_READ =
            kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
        static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
        static constexpr index_t SGradT_LDS_READ_P2 =
            kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
        static constexpr index_t OGrad_LDS_READ =
            kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
        static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);

        // LDS Write
        static constexpr index_t Q_LDS_WRITE =
            kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
        static constexpr index_t QT_LDS_WRITE =
            kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
        static constexpr index_t OGrad_LDS_WRITE =
            kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
        static constexpr index_t OGradT_LDS_WRITE =
            kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
        static constexpr index_t LSE_LDS_WRITE    = 1;
        static constexpr index_t D_LDS_WRITE      = 1;
        static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
    };
};
} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
View file @
f84e2020
...
...
@@ -8,9 +8,8 @@ namespace ck_tile {
 // This class is used for codegen pattern matching
 enum class BlockFmhaBwdPipelineEnum
 {
-    KSKTSVR = 0,
-    QSKSVROGradS,
-    KSVR,
+    KRKTRVR_IGLP = 0,
+    KRKTRVR,
 };
 } // namespace ck_tile
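The enum rename reflects the pipeline swap in this merge: the smem-staged variants (KSKTSVR, QSKSVROGradS, KSVR) give way to the register-resident KRKTRVR pipelines, with and without IGLP scheduling. A hypothetical helper, not part of the diff and with the enum redeclared locally for the sketch, showing the kind of codegen pattern matching the comment refers to:

// Sketch only: maps each pipeline enum value to a string tag a code
// generator might key on. The tag strings are assumptions.
#include <string_view>

enum class BlockFmhaBwdPipelineEnum { KRKTRVR_IGLP = 0, KRKTRVR };

constexpr std::string_view pipeline_tag(BlockFmhaBwdPipelineEnum e)
{
    switch(e)
    {
    case BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP: return "kr_ktr_vr_iglp";
    case BlockFmhaBwdPipelineEnum::KRKTRVR: return "kr_ktr_vr";
    }
    return "unknown";
}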
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
View file @
f84e2020
...
...
@@ -24,7 +24,9 @@ template <typename QDataType_,
           typename BiasGradDataType_,
           typename BlockFmhaShape_,
           bool kIsGroupMode_,
+          bool kIsDeterministic_,
           typename FmhaMask_,
+          typename FmhaDropout_,
           typename Traits_>
 struct BlockFmhaBwdPipelineProblem
 {
...
...
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
     using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
     using BlockFmhaShape   = remove_cvref_t<BlockFmhaShape_>;
     using FmhaMask         = remove_cvref_t<FmhaMask_>;
+    using FmhaDropout      = remove_cvref_t<FmhaDropout_>;
     using Traits           = remove_cvref_t<Traits_>;

     static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
     static constexpr bool kIsGroupMode  = kIsGroupMode_;
+    static constexpr bool kIsDeterministic = kIsDeterministic_;

     // attributes from traits
     static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
...
...
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
     static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
     static constexpr auto BiasEnum     = Traits::BiasEnum;
     static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
-    static constexpr bool kHasDropout  = Traits::kHasDropout;
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
 };
...
...
@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
 };

+template <typename AccDataType_,
+          typename QGradDataType_,
+          index_t kBlockSize_,
+          index_t kM0_,
+          index_t kN0_,
+          index_t kQKHeaddim_,
+          bool kIsGroupMode_,
+          bool kIsDeterministic_,
+          typename Traits_>
+struct BlockFmhaBwdConvertQGradPipelineProblem
+{
+    using AccDataType   = remove_cvref_t<AccDataType_>;
+    using QGradDataType = remove_cvref_t<QGradDataType_>;
+    using Traits        = remove_cvref_t<Traits_>;
+
+    static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
+                  "kBlockSize should be divisible by get_warp_size()");
+
+    static constexpr index_t kBlockSize     = kBlockSize_;
+    static constexpr index_t kM0            = kM0_;
+    static constexpr index_t kN0            = kN0_;
+    static constexpr index_t kQKHeaddim     = kQKHeaddim_;
+    static constexpr bool kIsGroupMode      = kIsGroupMode_;
+    static constexpr bool kIsDeterministic  = kIsDeterministic_;
+
+    // attributes from traits
+    static constexpr bool kPadSeqLenQ    = Traits::kPadSeqLenQ;
+    static constexpr bool kPadHeadDimQ   = Traits::kPadHeadDimQ;
+    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
+};
 } // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
f84e2020
...
...
@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
     static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };

+template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
+          bool kPadHeadDimQ_ /* padding for hdim_q */,
+          index_t kBlockPerCu_ = 2 /* hint to occupancy */>
+struct TileFmhaBwdConvertQGradTraits
+{
+    static constexpr bool kPadSeqLenQ    = kPadSeqLenQ_;
+    static constexpr bool kPadHeadDimQ   = kPadHeadDimQ_;
+    static constexpr index_t kBlockPerCu = kBlockPerCu_;
+};
 } // namespace ck_tile
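Together, the new traits and problem types parameterize the standalone dQ-convert pass added in this merge. A hypothetical instantiation; the tile sizes are assumptions, and the umbrella include path may differ from what is shown:

// Sketch only: every numeric template argument below is an assumed value.
#include "ck_tile/ops/fmha.hpp" // assumed umbrella header

using ConvertQGradTraits = ck_tile::TileFmhaBwdConvertQGradTraits</*kPadSeqLenQ=*/true,
                                                                  /*kPadHeadDimQ=*/true,
                                                                  /*kBlockPerCu=*/2>;

using ConvertQGradProblem = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
    float,           // AccDataType: dQ accumulated in fp32
    ck_tile::half_t, // QGradDataType: final dQ element type
    /*kBlockSize=*/256,
    /*kM0=*/64,
    /*kN0=*/128,
    /*kQKHeaddim=*/128,
    /*kIsGroupMode=*/false,
    /*kIsDeterministic=*/false,
    ConvertQGradTraits>;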
include/ck_tile/ops/gemm.hpp
View file @
f84e2020
...
...
@@ -5,6 +5,9 @@
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
0 → 100644
View file @
f84e2020
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"

namespace ck_tile {

// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // C += A * B
    template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensor& a_block_tensor,
                                   const BBlockTensor& b_block_tensor) const
    {
        static_assert(
            std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
                std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
                std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
            "wrong!");

        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;
        constexpr index_t KPerBlock = BlockGemmShape::kK;

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // M->N Warp
        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto b_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<MWarp>,
                                       tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<0, 1>>,
                                       tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        // check ABC-block-distribution
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
                           remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "A distribution is wrong!");
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
                           remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "B distribution is wrong!");
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "C distribution is wrong!");

        using AWarpDstr = typename WG::AWarpDstr;
        using BWarpDstr = typename WG::BWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using BWarpTensor = typename WG::BWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto b_warp_y_lengths =
            to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A Block window
                AWarpTensor a_warp_tensor;

                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read B warp tensor from B block tensor
                    BWarpTensor b_warp_tensor;

                    b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);

        return c_block_tensor;
    }

    // C = A * B
    template <typename ABlockTensor, typename BBlockTensor>
    CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
                                   const BBlockTensor& b_block_tensor) const
    {
        auto c_block_tensor = MakeCBlockTile();
        operator()(c_block_tensor, a_block_tensor, b_block_tensor);
        return c_block_tensor;
    }
};

} // namespace ck_tile
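The hot loop keeps K outermost so each A warp slice is read once and reused across all N iterations while C accumulates in registers. A host-side analog of that loop order, with plain arrays standing in for the distributed tensors and assumed iteration counts:

// Sketch only: scalar multiply-add stands in for the warp GEMM.
#include <array>
#include <cstdio>

int main()
{
    constexpr int MIter = 2, NIter = 2, KIter = 2; // assumed per-warp iteration counts

    std::array<float, MIter * KIter> a{};
    std::array<float, NIter * KIter> b{};
    std::array<float, MIter * NIter> c{};
    a.fill(1.0f);
    b.fill(2.0f);

    for(int k = 0; k < KIter; ++k)          // K outermost, as in the hot loop
        for(int m = 0; m < MIter; ++m)      // A slice loaded once per (k, m)
            for(int n = 0; n < NIter; ++n)  // reused across N
                c[m * NIter + n] += a[m * KIter + k] * b[n * KIter + k];

    std::printf("c[0]=%g\n", c[0]); // 1*2 accumulated over KIter -> 4
    return 0;
}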