Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
da7ef7e8
Unverified
Commit
da7ef7e8
authored
Jul 08, 2024
by
Jun Liu
Committed by
GitHub
Jul 08, 2024
Browse files
Merge pull request #88 from ROCm/merge_from_public
Merge from public
parents
9f2a6d43
dcd3d21a
Changes
379
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4530 additions
and
98 deletions
+4530
-98
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
+20
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp
...ipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp
+692
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
+20
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+1343
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
...k_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
+16
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
...ile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
+91
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+314
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...lock_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+175
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+666
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
...peline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
+770
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp
...ha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp
+19
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
+19
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+81
-13
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+39
-17
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+77
-27
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
...le/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
+25
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+15
-14
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+40
-6
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+49
-0
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+59
-1
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// This pipeline is v located in regs, k located in lds.
using
BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy
=
BlockFmhaBwdPipelineDefaultPolicy
<
/* QLoadOnce_ = */
false
,
/* QTLoadOnce_ = */
false
,
/* KLoadOnce_ = */
true
,
/* KTLoadOnce_ = */
false
,
/* VLoadOnce_ = */
true
,
/* OGradLoadOnce_ = */
false
,
/* OGradTLoadOnce_ = */
false
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
GemmDataType
=
remove_cvref_t
<
typename
Problem
::
GemmDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
DDataType
=
remove_cvref_t
<
typename
Problem
::
DDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
using
QGradDataType
=
remove_cvref_t
<
typename
Problem
::
QGradDataType
>
;
using
KGradDataType
=
remove_cvref_t
<
typename
Problem
::
KGradDataType
>
;
using
VGradDataType
=
remove_cvref_t
<
typename
Problem
::
VGradDataType
>
;
using
BiasGradDataType
=
remove_cvref_t
<
typename
Problem
::
BiasGradDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
static
constexpr
index_t
kBlockPerCu
=
Problem
::
kBlockPerCu
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK2
=
BlockFmhaShape
::
kK2
;
static
constexpr
index_t
kK3
=
BlockFmhaShape
::
kK3
;
static
constexpr
index_t
kK4
=
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
bool
kQLoadOnce
=
true
;
static
constexpr
bool
kQTLoadOnce
=
false
;
static
constexpr
bool
kKLoadOnce
=
true
;
static
constexpr
bool
kKTLoadOnce
=
false
;
static
constexpr
bool
kVLoadOnce
=
true
;
static
constexpr
bool
kOGradLoadOnce
=
true
;
static
constexpr
bool
kOGradTLoadOnce
=
false
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
kPadHeadDimQ
?
2
:
Policy
::
template
GetAlignmentQGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentKGrad
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentKGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentVGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentVGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetTransposedAlignmentBias
<
Problem
>();
static
constexpr
const
char
*
name
=
"qs_ks_vr_dos"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
QTDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KTDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
OGradDramBlockWindowTmp
,
typename
OGradTDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
DDramBlockWindowTmp
,
typename
QGradDramBlockWindowTmp
,
typename
BiasGradDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
const
QTDramBlockWindowTmp
&
/*qt_dram_block_window_tmp*/
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
const
KTDramBlockWindowTmp
&
/*kt_dram_block_window_tmp*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
const
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
const
OGradDramBlockWindowTmp
&
do_dram_block_window_tmp
,
const
OGradTDramBlockWindowTmp
&
/*dot_dram_block_window_tmp*/
,
const
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
const
DDramBlockWindowTmp
&
d_dram_block_window_tmp
,
const
QGradDramBlockWindowTmp
&
dq_dram_block_window_tmp
,
const
BiasGradDramBlockWindowTmp
&
dbias_dram_block_window_tmp
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
,
#endif
float
rp_undrop
,
float
scale_rp_undrop
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
LSEDataType
,
remove_cvref_t
<
typename
LSEDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QGradDataType
,
remove_cvref_t
<
typename
QGradDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
OGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
LSEDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
DDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
QGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// Q tile in LDS
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// QT tile in LDS
auto
qt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptorAsQT
<
Problem
>());
auto
qt_lds_window
=
make_tile_window
(
qt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
});
// K tile in LDS
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// KT tile in LDS
auto
kt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptorAsKT
<
Problem
>());
auto
kt_lds_window
=
make_tile_window
(
kt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// OGrad tile in LDS
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kVHeaddim
>
{}),
{
0
,
0
});
// OGradT tile in LDS
auto
dot_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptorAsOGradT
<
Problem
>());
auto
dot_lds_window
=
make_tile_window
(
dot_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
});
// SGrad tile in LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType
*
biast_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
biast_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
biast_lds_ptr
,
Policy
::
template
MakeBiasTLdsBlockDescriptor
<
Problem
>());
auto
biast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
dbiast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetPTOGradTBlockGemm
<
Problem
>();
constexpr
auto
gemm_2
=
Policy
::
template
GetOGradVBlockGemm
<
Problem
>();
constexpr
auto
gemm_3
=
Policy
::
template
GetSGradTQTBlockGemm
<
Problem
>();
constexpr
auto
gemm_4
=
Policy
::
template
GetSGradKTBlockGemm
<
Problem
>();
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVInRegDramTileDistribution
<
Problem
,
decltype
(
gemm_2
)>());
auto
v
=
load_tile
(
v_dram_window
);
// persistent V register tile
using
SPTBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradTBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
// init VGrad & KGrad
auto
dv_acc
=
decltype
(
gemm_1
.
MakeCBlockTile
()){};
auto
dk_acc
=
decltype
(
gemm_3
.
MakeCBlockTile
()){};
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
k_origin
=
k_dram_window
.
get_window_origin
();
const
auto
[
seqlen_q_start
,
seqlen_q_end
]
=
mask
.
GetTileRangeAlongY
(
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_q_end
-
seqlen_q_start
,
kM0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
)
{
if
(
num_total_loop
<=
0
)
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
}
}
auto
k_block_tile
=
load_tile
(
k_dram_window
);
store_tile
(
k_lds_window
,
k_block_tile
);
// // persistent K in LDS
auto
q_dram_block_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
do_dram_block_window
=
make_tile_window
(
do_dram_block_window_tmp
.
get_bottom_tensor_view
(),
do_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
dq_dram_block_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
lse_dram_block_window
=
make_tile_window
(
lse_dram_block_window_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
d_dram_block_window
=
make_tile_window
(
d_dram_block_window_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_block_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})});
// M/N
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_block_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
auto
lse_dram_window
=
make_tile_window
(
lse_dram_block_window
.
get_bottom_tensor_view
(),
lse_dram_block_window
.
get_window_lengths
(),
lse_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
d_dram_window
=
make_tile_window
(
d_dram_block_window
.
get_bottom_tensor_view
(),
d_dram_block_window
.
get_window_lengths
(),
d_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window
.
get_bottom_tensor_view
(),
bias_dram_block_window
.
get_window_lengths
(),
bias_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
auto
biast_lds_window
=
make_tile_window
(
biast_lds_shuffle_window
.
get_bottom_tensor_view
(),
biast_lds_shuffle_window
.
get_window_lengths
(),
biast_lds_shuffle_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTTileDistribution
<
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>
(
randval_dram_block_window_tmp
,
seqlen_q_start
);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kM0
/
kK1
;
constexpr
index_t
k2_loops
=
kVHeaddim
/
kK2
;
constexpr
index_t
k3_loops
=
kM0
/
kK3
;
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
do
{
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window
.
get_bottom_tensor_view
(),
q_dram_block_window
.
get_window_lengths
(),
q_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
// Q DRAM tile window for
// load
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window
.
get_bottom_tensor_view
(),
do_dram_block_window
.
get_window_lengths
(),
do_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
// OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{};
auto
q_block_tile
=
load_tile
(
q_dram_window
);
clear_tile
(
st_acc
);
// Initialize S^T
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
if
constexpr
(
k0_loops
>
1
)
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
st_acc
,
get_slice_tile
(
q_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
block_sync_lds
();
});
}
auto
do_block_tile
=
load_tile
(
do_dram_window
);
// prefetch load OGrad tile
{
// tail
block_sync_lds
();
gemm_0
(
st_acc
,
get_slice_tile
(
q_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kN0
,
k0_loops
*
kK0
>
{}));
block_sync_lds
();
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
block_sync_lds
();
auto
bias_shuffle_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_shuffle_tmp
,
bias_tile
);
store_tile
(
biast_lds_shuffle_window
,
bias_shuffle_tmp
);
block_sync_lds
();
auto
biast_tile
=
load_tile
(
biast_lds_window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
=
raw_scale
*
x
+
type_convert
<
AccDataType
>
(
y
);
#else
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
#endif
},
st_acc
,
biast_tile
);
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
constexpr
auto
st_spans
=
decltype
(
st_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
st_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
st_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
st_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc
(
i_j_idx
)
*=
raw_scale
;
#else
st_acc
(
i_j_idx
)
*=
scale
;
#endif
position_encoding
.
update
(
st_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
st_acc
);
#endif
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
st_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
lse
=
load_tile
(
lse_dram_window
);
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
raw_lse
==
-
numeric
<
LSEDataType
>::
infinity
()
?
type_convert
<
LSEDataType
>
(
0.
f
)
:
raw_lse
;
}
else
{
return
raw_lse
;
}
};
auto
pt
=
SPTBlockTileType
{};
constexpr
auto
pt_spans
=
decltype
(
pt
)
::
get_distributed_spans
();
sweep_tile_span
(
pt_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
#endif
sweep_tile_span
(
pt_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
pt
(
i_j_idx
)
=
exp2
(
st_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
pt
(
i_j_idx
)
=
exp2
(
scale
*
st_acc
[
i_j_idx
]
-
row_lse
);
}
#else
pt
(
i_j_idx
)
=
exp
(
st_acc
[
i_j_idx
]
-
get_validated_lse
(
lse
[
i_idx
]));
#endif
});
});
if
constexpr
(
kHasDropout
)
{
dropout
.
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>
(
seqlen_q_start
+
i_total_loops
*
kM0
,
pt
,
randval_dram_window
);
}
// STAGE 3, P^T@OGrad^T Gemm1
block_sync_lds
();
store_tile
(
do_lds_window
,
do_block_tile
);
// store the prefetch
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
kHasDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
pt
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
pt
);
}
}();
static_for
<
0
,
k1_loops
,
1
>
{}([
&
](
auto
i_k1
)
{
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
i_k1
*
kK1
,
0
>
{},
sequence
<
(
i_k1
+
1
)
*
kK1
,
kN0
>
{}),
get_slice_tile
(
dot_lds_window
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kVHeaddim
,
(
i_k1
+
1
)
*
kK1
>
{}));
block_sync_lds
();
});
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
clear_tile
(
dpt_acc
);
// Initialize PGrad^T
static_for
<
0
,
k2_loops
,
1
>
{}([
&
](
auto
i_k2
)
{
block_sync_lds
();
gemm_2
(
dpt_acc
,
get_slice_tile
(
do_lds_window
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kM0
,
(
i_k2
+
1
)
*
kK2
>
{}),
get_slice_tile
(
v
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kN0
,
(
i_k2
+
1
)
*
kK2
>
{}));
block_sync_lds
();
});
// STAGE 5, P^T(PGrad^T - D)
const
auto
d
=
load_tile
(
d_dram_window
);
auto
dst
=
SPGradTBlockTileType
{};
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
sweep_tile_span
(
dst_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
dst_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
pt
[
i_j_idx
]
>=
0
;
dst
(
i_j_idx
)
=
pt
[
i_j_idx
]
*
(
!
kHasDropout
||
undrop_flag
?
(
dpt_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbiast
=
[
&
]()
{
if
constexpr
(
kHasDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
dst
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
dst
);
}
}();
store_tile
(
biast_lds_shuffle_window
,
dbiast
);
block_sync_lds
();
auto
dbiast_tile
=
load_tile
(
dbiast_lds_shuffle_window
);
auto
dbiast_shuffle_tmp
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
store_tile
(
dbias_dram_block_window
,
dbiast_shuffle_tmp
);
move_tile_window
(
dbias_dram_block_window
,
{
kM0
,
0
});
}
// STAGE 6, SGrad^T@Q^T Gemm3
block_sync_lds
();
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
static_for
<
0
,
k3_loops
,
1
>
{}([
&
](
auto
i_k3
)
{
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
i_k3
*
kK3
,
0
>
{},
sequence
<
(
i_k3
+
1
)
*
kK3
,
kN0
>
{}),
get_slice_tile
(
qt_lds_window
,
sequence
<
0
,
i_k3
*
kK3
>
{},
sequence
<
kQKHeaddim
,
(
i_k3
+
1
)
*
kK3
>
{}));
block_sync_lds
();
});
// STAGE 7, SGrad@K^T Gemm4
store_tile
(
ds_lds_window
,
dst_gemm
);
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
// Initialize QGrad
block_sync_lds
();
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
gemm_4
(
dq_acc
,
get_slice_tile
(
ds_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kM0
,
(
i_k4
+
1
)
*
kK4
>
{}),
get_slice_tile
(
kt_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{}));
});
// QGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
}
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
update_tile
(
dq_dram_block_window
,
dq
);
// move tile windows
move_tile_window
(
q_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
dq_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
do_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
lse_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
}
while
(
++
i_total_loops
<
num_total_loop
);
// KGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dk_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
}
// VGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// This pipeline is v located in regs, q & k & do located in lds.
using
BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy
=
BlockFmhaBwdPipelineDefaultPolicy
<
/* QLoadOnce_ = */
true
,
/* QTLoadOnce_ = */
false
,
/* KLoadOnce_ = */
true
,
/* KTLoadOnce_ = */
false
,
/* VLoadOnce_ = */
true
,
/* OGradLoadOnce_ = */
true
,
/* OGradTLoadOnce_ = */
false
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace
ck_tile
{
template
<
bool
QLoadOnce_
,
bool
QTLoadOnce_
,
bool
KLoadOnce_
,
bool
KTLoadOnce_
,
bool
VLoadOnce_
,
bool
OGradLoadOnce_
,
bool
OGradTLoadOnce_
>
struct
BlockFmhaBwdPipelineDefaultPolicy
{
static
constexpr
bool
QLoadOnce
=
QLoadOnce_
;
// if q load whole block length (qkhdim) to LDS at once
static
constexpr
bool
QTLoadOnce
=
QTLoadOnce_
;
// if q^t load whole block length (qkhdim) to LDS at once
static
constexpr
bool
KLoadOnce
=
KLoadOnce_
;
// if k load whole block length (qkhdim) to LDS at once
static
constexpr
bool
KTLoadOnce
=
KTLoadOnce_
;
// if k^t load whole block length (qkhdim) to LDS at once
static
constexpr
bool
VLoadOnce
=
VLoadOnce_
;
// if v load whole block length (vhdim) to Vgprs at once
static
constexpr
bool
OGradLoadOnce
=
OGradLoadOnce_
;
// if do load whole block length (vhdim) to LDS at once
static
constexpr
bool
OGradTLoadOnce
=
OGradTLoadOnce_
;
// if do^t load whole block length (vhdim) to LDS at once
// these are for global load
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
16
/
sizeof
(
QDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentK
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
16
/
sizeof
(
KDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentV
()
{
if
constexpr
(
VLoadOnce
)
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
}
else
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
return
16
/
sizeof
(
VDataType
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentO
()
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
return
16
/
sizeof
(
ODataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentOGrad
()
{
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
return
16
/
sizeof
(
OGradDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQGrad
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradKTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentKGrad
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradTQTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentVGrad
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetPTOGradTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentQ
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentK
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentOGrad
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentBias
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
32
)
return
8
;
else
return
4
;
}
// these are for lds
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
{
// TODO: this is for 3d layout
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
16
/
sizeof
(
QDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackK
()
{
// TODO: this is for 3d layout
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
16
/
sizeof
(
KDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackV
()
{
// TODO: this is for 3d layout
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
return
16
/
sizeof
(
VDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBias
()
{
// TODO: this is for 3d layout
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
return
16
/
sizeof
(
BiasDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackOGrad
()
{
// TODO: this is for 3d layout
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
return
16
/
sizeof
(
OGradDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackSGrad
()
{
// TODO: this is for 3d layout
using
GemmDataType
=
remove_cvref_t
<
typename
Problem
::
GemmDataType
>
;
return
16
/
sizeof
(
GemmDataType
);
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVInRegDramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WG
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
// 3d + padding
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptor
()
{
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
+
1
)
*
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
x_lds_block_desc
=
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
MNPerBlock
),
make_merge_transform
(
make_tuple
(
KPerBlock
/
KPack
,
KPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
x_lds_block_desc
;
}
// 3d + padding
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptorAsXT
()
{
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
+
1
)
*
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
xt_lds_block_desc
=
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
MNPerBlock
),
make_merge_transform
(
make_tuple
(
KPerBlock
/
KPack
,
KPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
xt_lds_block_desc
;
}
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
,
index_t
PixelsPerRow
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXTLdsBlockDescriptor
()
{
static_assert
(
PixelsPerRow
%
KPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
KPack
;
static_assert
(
MNPerBlock
%
NPerRow
==
0
);
static_assert
(
KPerBlock
%
KPack
==
0
);
constexpr
auto
xt_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
KPack
)
>
{},
number
<
PixelsPerRow
+
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
xt_lds_block_desc
=
transform_tensor_descriptor
(
xt_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
MNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
xt_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptorAsQT
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
return
MakeXLdsBlockDescriptorAsXT
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsBlockDescriptorAsKT
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
MakeXLdsBlockDescriptorAsXT
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradLdsBlockDescriptorAsOGradT
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
return
MakeXLdsBlockDescriptorAsXT
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSGradLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPack
=
GetSmemKPackSGrad
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQTLdsBlockDescriptor
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
QDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKTLdsBlockDescriptor
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
KDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradTLdsBlockDescriptor
()
{
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
OGradDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasTLdsBlockDescriptor
()
{
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
BiasDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kMPerBlock
%
kKPack
==
0
);
constexpr
auto
biast_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
)
>
{},
number
<
PixelsPerRow
+
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
biast_lds_block_desc
=
transform_tensor_descriptor
(
biast_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kMPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
biast_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
constexpr
index_t
smem_size_q
=
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_q
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQT
()
{
constexpr
index_t
smem_size_qt
=
[
&
]()
{
if
constexpr
(
QLoadOnce
&&
!
QTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_qt
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeK
()
{
constexpr
index_t
smem_size_k
=
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_k
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeKT
()
{
constexpr
index_t
smem_size_kt
=
[
&
]()
{
if
constexpr
(
KLoadOnce
&&
!
KTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_kt
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeV
()
{
constexpr
index_t
smem_size_v
=
[
&
]()
{
if
constexpr
(
VLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
VDataType
)
*
MakeVLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_v
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeOGrad
()
{
constexpr
index_t
smem_size_do
=
sizeof
(
typename
Problem
::
OGradDataType
)
*
MakeOGradLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_do
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeOGradT
()
{
constexpr
index_t
smem_size_dot
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
&&
!
OGradTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
OGradDataType
)
*
MakeOGradTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_dot
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeSGrad
()
{
constexpr
index_t
smem_size_ds
=
sizeof
(
typename
Problem
::
GemmDataType
)
*
MakeSGradLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_ds
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeBias
()
{
constexpr
index_t
smem_size_bias
=
[
&
]()
{
if
constexpr
(
Problem
::
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
sizeof
(
typename
Problem
::
BiasDataType
)
*
MakeBiasTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
else
return
0
;
}();
return
smem_size_bias
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_q
=
GetSmemSizeQ
<
Problem
>
();
constexpr
index_t
smem_size_qt
=
GetSmemSizeQT
<
Problem
>
();
constexpr
index_t
smem_size_k
=
GetSmemSizeK
<
Problem
>
();
constexpr
index_t
smem_size_kt
=
GetSmemSizeKT
<
Problem
>
();
constexpr
index_t
smem_size_v
=
GetSmemSizeV
<
Problem
>
();
constexpr
index_t
smem_size_do
=
GetSmemSizeOGrad
<
Problem
>
();
constexpr
index_t
smem_size_dot
=
GetSmemSizeOGradT
<
Problem
>
();
constexpr
index_t
smem_size_ds
=
GetSmemSizeSGrad
<
Problem
>
();
constexpr
index_t
smem_size_bias
=
GetSmemSizeBias
<
Problem
>
();
constexpr
index_t
smem_size_transpose
=
max
(
smem_size_ds
,
smem_size_bias
);
index_t
smem_size
=
0
;
if
constexpr
(
QLoadOnce
&&
OGradLoadOnce
)
smem_size
+=
smem_size_q
+
smem_size_qt
+
smem_size_do
+
smem_size_dot
+
smem_size_transpose
;
// 1~4 & 10
else
if
(
QLoadOnce
&&
!
OGradLoadOnce
&&
!
OGradTLoadOnce
)
smem_size
+=
smem_size_q
+
smem_size_qt
+
max
(
smem_size_do
,
smem_size_dot
,
smem_size_transpose
);
// 5/7/11 TODO: Multiple buffers strategy
else
if
(
!
QLoadOnce
&&
!
QTLoadOnce
&&
OGradLoadOnce
)
smem_size
+=
smem_size_do
+
smem_size_dot
+
max
(
smem_size_q
,
smem_size_qt
,
smem_size_transpose
);
// 6/8/12 TODO: Multiple buffers strategy
else
if
(
!
QLoadOnce
&&
!
QTLoadOnce
&&
!
OGradLoadOnce
&&
!
OGradTLoadOnce
)
smem_size
+=
max
(
smem_size_q
,
smem_size_qt
,
smem_size_do
,
smem_size_dot
,
smem_size_transpose
);
// 9/13 TODO: Multiple buffers strategy
// 14/15 needs to be adjusted
if
constexpr
(
KLoadOnce
)
smem_size
+=
(
smem_size_k
+
smem_size_kt
);
// 1~13
else
smem_size
=
max
(
smem_size_k
,
smem_size_kt
,
smem_size
);
// 14/15 TODO: Multiple buffers strategy
return
max
(
smem_size
,
smem_size_v
);
// 15
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEDDramTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
N1
=
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
;
constexpr
index_t
N0
=
NWarp
;
constexpr
index_t
M4
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
*
2
;
constexpr
index_t
M3
=
WG
::
WarpGemmAttribute
::
Impl
::
kCMLane
;
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM0PerLane
/
2
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
WG
::
WarpGemmAttribute
::
Impl
::
kM
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
N0
,
N1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
,
M4
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
3
,
1
>>
,
sequence
<
1
,
1
,
1
>
,
sequence
<
0
,
2
,
4
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVDramTileDistribution
()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
K1
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
DataType
,
index_t
MPerBlock
,
index_t
KPerBlock
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePreXDramTileDistribution
()
{
constexpr
index_t
K1
=
16
/
sizeof
(
DataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
1
;
constexpr
index_t
M1
=
get_warp_size
();
constexpr
index_t
M0
=
MPerBlock
/
M1
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
2
,
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePreODramTileDistribution
()
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
return
MakePreXDramTileDistribution
<
ODataType
,
kBlockSize
,
kKPerBlock
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePreOGradDramTileDistribution
()
{
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
return
MakePreXDramTileDistribution
<
OGradDataType
,
kBlockSize
,
kKPerBlock
>
();
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeQTDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentQ
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledQTRegBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentQ
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeKTDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentK
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledKTRegBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentK
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeOGradTDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentOGrad
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledOGradTRegBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentOGrad
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeBiasTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
N1
=
GetTransposedAlignmentBias
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
M3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
static_assert
(
kKPack
%
M3
==
0
);
constexpr
index_t
M2
=
kKPack
/
M3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
M1
=
get_warp_size
()
/
(
M2
*
N0
);
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kMPerBlock
==
M0
*
M1
*
M2
*
M3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
3
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBiasTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
N1
=
GetTransposedAlignmentBias
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
M3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
static_assert
(
kKPack
%
M3
==
0
);
constexpr
index_t
M2
=
kKPack
/
M3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
M1
=
get_warp_size
()
/
(
M2
*
N0
);
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
1
,
3
>>
{});
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasTTileDistribution
()
{
using
c_block_tensor_type
=
decltype
(
BlockGemm
{}.
MakeCBlockTile
());
return
c_block_tensor_type
::
get_tile_distribution
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
{};
}
}();
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
OGradDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
VDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
OGradDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
VDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
{};
}
}();
using
BlockGemmPolicy
=
BlockGemmASmemBRegCRegV1CustomPolicy
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBRegCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
// {
// using BlockGemmProblem =
// BlockGemmPipelineProblem<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// Problem::kBlockSize,
// TileGemmShape<Problem::BlockFmhaShape::kM0,
// Problem::BlockFmhaShape::kN0,
// Problem::BlockFmhaShape::kK2>>;
// constexpr auto warp_gemm = []() {
// if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
// std::is_same_v<typename Problem::VDataType, half_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
// }
// else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
// std::is_same_v<typename Problem::VDataType, bf16_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
// }
// }();
// using BlockGemmPolicy =
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// typename
// Problem::BlockFmhaShape::Gemm2BlockWarps,
// decltype(warp_gemm)>;
// return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
// }
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
return
BlockGemmASmemBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck_tile
{
// This class is used for codegen pattern matching
enum
class
BlockFmhaBwdPipelineEnum
{
KSKTSVR
=
0
,
QSKSVROGradS
,
KSVR
,
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
QDataType_
,
typename
KDataType_
,
typename
VDataType_
,
typename
GemmDataType_
,
typename
LSEDataType_
,
typename
AccDataType_
,
typename
DDataType_
,
typename
BiasDataType_
,
typename
RandValOutputDataType_
,
typename
ODataType_
,
typename
OGradDataType_
,
typename
QGradDataType_
,
typename
KGradDataType_
,
typename
VGradDataType_
,
typename
BiasGradDataType_
,
typename
BlockFmhaShape_
,
bool
kIsGroupMode_
,
typename
FmhaMask_
,
typename
Traits_
>
struct
BlockFmhaBwdPipelineProblem
{
using
QDataType
=
remove_cvref_t
<
QDataType_
>
;
using
KDataType
=
remove_cvref_t
<
KDataType_
>
;
using
VDataType
=
remove_cvref_t
<
VDataType_
>
;
using
GemmDataType
=
remove_cvref_t
<
GemmDataType_
>
;
using
LSEDataType
=
remove_cvref_t
<
LSEDataType_
>
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
DDataType
=
remove_cvref_t
<
DDataType_
>
;
using
BiasDataType
=
remove_cvref_t
<
BiasDataType_
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
RandValOutputDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
OGradDataType
=
remove_cvref_t
<
OGradDataType_
>
;
using
QGradDataType
=
remove_cvref_t
<
QGradDataType_
>
;
using
KGradDataType
=
remove_cvref_t
<
KGradDataType_
>
;
using
VGradDataType
=
remove_cvref_t
<
VGradDataType_
>
;
using
BiasGradDataType
=
remove_cvref_t
<
BiasGradDataType_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Traits
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Traits
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Traits
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Traits
::
kHasDropout
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
template
<
typename
ODataType_
,
typename
OGradDataType_
,
typename
DDataType_
,
index_t
kBlockSize_
,
index_t
kVHeaddim_
,
bool
kIsGroupMode_
,
typename
Traits_
>
struct
BlockFmhaBwdOGradDotOPipelineProblem
{
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
OGradDataType
=
remove_cvref_t
<
OGradDataType_
>
;
using
DDataType
=
remove_cvref_t
<
DDataType_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static_assert
(
0
<
kBlockSize_
&&
kBlockSize_
%
get_warp_size
()
==
0
,
"kBlockSize should be divisible by get_warp_size()"
);
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
static
constexpr
index_t
kVHeaddim
=
kVHeaddim_
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
namespace
detail
{
template
<
index_t
N
>
struct
log2
;
template
<
>
struct
log2
<
16
>
:
std
::
integral_constant
<
index_t
,
4
>
{
};
template
<
>
struct
log2
<
32
>
:
std
::
integral_constant
<
index_t
,
5
>
{
};
template
<
>
struct
log2
<
64
>
:
std
::
integral_constant
<
index_t
,
6
>
{
};
template
<
>
struct
log2
<
128
>
:
std
::
integral_constant
<
index_t
,
7
>
{
};
}
// namespace detail
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
>
struct
BlockFmhaFwdSplitKVCombinePipeline
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kHeadDimV
=
Problem
::
kHeadDimV
;
static
constexpr
index_t
kM0
=
Problem
::
kM0
;
static
constexpr
index_t
kN1
=
Problem
::
kN1
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
index_t
kMaxSplits
=
Problem
::
kMaxSplits
;
static
constexpr
index_t
kAlignmentLSE
=
kPadSeqLenQ
?
1
:
Policy
::
template
GetAlignmentLSE
<
Problem
>();
static
constexpr
index_t
kAlignmentLSEacc
=
kAlignmentLSE
;
static
constexpr
index_t
kAlignmentOacc
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOacc
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kHeadDimV
<=
32
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
3
,
3
,
3
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
}
else
if
constexpr
(
kHeadDimV
<=
128
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
3
,
3
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
}
else
if
constexpr
(
kHeadDimV
<=
256
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
}
}
}();
static
constexpr
const
char
*
name
=
"unused"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
LSEaccDramBlockWindowTmp
,
typename
OaccDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
LSEElementFunction
,
typename
OaccElementFunction
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
LSEaccDramBlockWindowTmp
&
lse_acc_dram_block_window_tmp
,
const
OaccDramBlockWindowTmp
&
o_acc_dram_block_window_tmp
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
const
LSEElementFunction
&
lse_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
max_seqlen_q
,
void
*
smem_ptr
)
const
{
// lse_acc tile in LDS
LSEDataType
*
lse_acc_lds_ptr
=
static_cast
<
LSEDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
lse_acc_lds
=
[
=
,
lds_desc
=
Policy
::
template
MakeLSEaccLdsBlockDescriptor
<
Problem
>()](
index_t
row
,
index_t
col
)
->
LSEDataType
&
{
return
lse_acc_lds_ptr
[
lds_desc
.
calculate_offset
(
make_tuple
(
row
,
col
))];
};
auto
lse_acc_lds_write_window
=
[
&
]()
{
auto
view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
lse_acc_lds_ptr
,
Policy
::
template
MakeLSEaccLdsStoreBlockDescriptor
<
Problem
>());
return
make_tile_window
(
view
,
make_tuple
(
number
<
kMaxSplits
>
{},
number
<
kM0
>
{}),
{
0
,
0
});
}();
auto
lse_acc_dram_window
=
make_tile_window
(
lse_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
lse_acc_dram_block_window_tmp
.
get_window_lengths
(),
lse_acc_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeLSEaccDramTileDistribution
<
Problem
>());
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto
lse_acc_tile
=
load_tile
(
lse_acc_dram_window
);
store_tile
(
lse_acc_lds_write_window
,
lse_acc_tile
);
block_sync_lds
();
auto
lse_accum
=
make_static_distributed_tensor
<
LSEDataType
>
(
Policy
::
template
MakeLSEaccRegTileDistribution
<
Problem
>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in wave have data to
// reduce.
{
constexpr
auto
spans
=
decltype
(
lse_accum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
lse_accum
.
get_tile_distribution
(),
i_j_idx
);
const
auto
col
=
x_indices
.
at
(
number
<
1
>
{});
if
(
col
<
num_splits
)
{
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
lse_accum
(
i_j_idx
)
=
lse_acc_lds
(
row
,
col
);
}
else
{
lse_accum
(
i_j_idx
)
=
-
numeric
<
LSEDataType
>::
infinity
();
}
});
});
}
// compute the logsumexp of the LSE along the split dimension.
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
ck_tile
::
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
auto
lse_max
=
block_tile_reduce
<
LSEDataType
>
(
lse_accum
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
LSEDataType
>::
infinity
());
block_tile_reduce_sync
(
lse_max
,
f_max
,
bool_constant
<
false
>
{});
static
const
auto
get_validated_m
=
[](
LSEDataType
raw_m
)
{
return
raw_m
==
-
numeric
<
LSEDataType
>::
infinity
()
?
type_convert
<
LSEDataType
>
(
0.
f
)
:
raw_m
;
};
decltype
(
lse_accum
)
lse_exp
;
{
constexpr
auto
spans
=
decltype
(
lse_exp
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
lse_exp
(
i_j_idx
)
=
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
get_validated_m
(
lse_max
(
i_idx
)));
});
});
}
auto
lse_sum
=
block_tile_reduce
<
LSEDataType
>
(
lse_exp
,
sequence
<
1
>
{},
f_sum
,
type_convert
<
LSEDataType
>
(
0
));
block_tile_reduce_sync
(
lse_sum
,
f_sum
,
bool_constant
<
false
>
{});
decltype
(
lse_max
)
lse_logsum
;
{
constexpr
auto
spans
=
decltype
(
lse_logsum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
if
(
lse_sum
(
i_idx
)
==
0.
f
||
lse_sum
(
i_idx
)
!=
lse_sum
(
i_idx
))
{
lse_logsum
(
i_idx
)
=
numeric
<
LSEDataType
>::
infinity
();
}
else
{
lse_logsum
(
i_idx
)
=
ck_tile
::
log
(
lse_sum
(
i_idx
))
+
get_validated_m
(
lse_max
(
i_idx
));
}
});
}
// store the lse scales in shared memory.
{
constexpr
auto
spans
=
decltype
(
lse_accum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
lse_accum
.
get_tile_distribution
(),
i_j_idx
);
const
auto
col
=
x_indices
.
at
(
number
<
1
>
{});
if
(
col
<
num_splits
)
{
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
lse_acc_lds
(
row
,
col
)
=
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
lse_logsum
(
i_idx
));
}
});
});
}
block_sync_lds
();
if
constexpr
(
kStoreLSE
)
{
constexpr
auto
spans
=
decltype
(
lse_logsum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
if
(
lse_logsum
(
i_idx
)
==
numeric
<
LSEDataType
>::
infinity
())
{
lse_logsum
(
i_idx
)
=
-
numeric
<
LSEDataType
>::
infinity
();
}
});
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse_logsum
));
}
auto
o_acc_dist
=
Policy
::
template
MakeOaccDramTileDistribution
<
Problem
>();
auto
o_acc_dram_window
=
make_tile_window
(
o_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
o_acc_dram_block_window_tmp
.
get_window_lengths
(),
o_acc_dram_block_window_tmp
.
get_window_origin
(),
o_acc_dist
);
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
const
index_t
padded_max_seqlen_q
=
integer_divide_ceil
(
max_seqlen_q
,
kM0
)
*
kM0
;
for
(
index_t
i_split
=
0
;
i_split
<
num_splits
;
++
i_split
)
{
auto
o_tile
=
load_tile
(
o_acc_dram_window
);
{
constexpr
auto
spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
o_acc
.
get_tile_distribution
(),
i_j_idx
);
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
const
LSEDataType
lse_scale
=
lse_acc_lds
(
row
,
i_split
);
o_acc
(
i_j_idx
)
+=
lse_scale
*
o_tile
(
i_j_idx
);
});
});
}
move_tile_window
(
o_acc_dram_window
,
{
padded_max_seqlen_q
,
0
});
}
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
template
<
typename
LSEaccDramBlockWindow
,
typename
OaccDramBlockWindow
,
typename
LSEDramBlockWindow
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
LSEaccDramBlockWindow
&
lse_acc_dram_block_window
,
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
index_t
num_splits
,
index_t
max_seqlen_q
,
void
*
smem_ptr
)
const
{
return
operator
()(
lse_acc_dram_block_window
,
o_acc_dram_block_window
,
lse_dram_block_window
,
identity
{},
identity
{},
num_splits
,
max_seqlen_q
,
smem_ptr
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace
ck_tile
{
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
return
16
/
sizeof
(
LSEDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentOacc
()
{
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
return
16
/
sizeof
(
OaccDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentO
()
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
return
16
/
sizeof
(
ODataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
sizeof
(
typename
Problem
::
LSEDataType
)
*
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NPerThread
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
TotalWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
TotalWarps
*
MThreadsPerWarp
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MPerThread
*
TotalWarps
*
MThreadsPerWarp
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
TotalWarps
,
MThreadsPerWarp
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// 3d + padding, [kMaxSplits, kM0]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsStoreBlockDescriptor
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
lse_acc_lds_block_desc
=
transform_tensor_descriptor
(
lse_acc_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kMPerBlock
),
make_merge_transform
(
make_tuple
(
kNPerBlock
/
NPack
,
NPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lse_acc_lds_block_desc
;
}
// 3d + padding, [kM0, kMaxSplits]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsBlockDescriptor
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
lse_acc_t_lds_block_desc
=
transform_tensor_descriptor
(
lse_acc_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kMPerBlock
),
make_merge_transform
(
make_tuple
(
kNPerBlock
/
NPack
,
NPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
lse_acc_t_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccRegTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
max
(
Problem
::
kMaxSplits
,
get_warp_size
());
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NThreads
=
get_warp_size
();
constexpr
index_t
NPerThread
=
kNPerBlock
/
NThreads
;
constexpr
index_t
MThreads
=
kBlockSize
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
MThreads
;
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MThreads
*
MPerThread
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MThreads
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOaccDramTileDistribution
()
{
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
N1
=
16
/
sizeof
(
OaccDataType
);
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
M2
=
get_warp_size
()
/
N0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
>
struct
BlockFmhaFwdSplitKVPipelineQRKSVS
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kHasDropout
=
false
;
// ignore this flag
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kK0BlockLength
<=
32
)
{
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qr"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEaccElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
k_element_func
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEaccDramBlockWindowTmp
&
lse_acc_dram_window_tmp
,
// M0*1 tile
const
LSEaccElementFunction
&
lse_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
q
=
load_tile
(
q_dram_window
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
||
kHasUnevenSplits
)
{
if
(
num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
auto
lse_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return
o_acc
;
}
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
)
>
(
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
auto
q_tile
=
tile_elementwise_in
(
q_element_func
,
q
);
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
do
{
// STAGE 1, QK gemm
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
auto
k_block_tile
=
load_tile
(
k_dram_window
);
{
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
clear_tile
(
s_acc
);
// initialize C
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
k_block_tile
=
load_tile
(
k_dram_window
);
}
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
// LDS write i + 1
k_block_tile
=
load_tile
(
k_dram_window
);
// global read i + 2
});
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
sequence
<
kM0
,
(
k0_loops
-
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
k_lds_window
);
}
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
+=
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#else
x
+=
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
*=
scale_s
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
/// TODO: only check in last iteration without increasing code size
if
constexpr
(
kHasUnevenSplits
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
seqlen_k_end_
=
seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
seqlen_k_end_
<=
col
;
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
if
constexpr
(
kHasDropout
)
{
dropout
.
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>
(
smem_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_prefetch
);
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window
);
// load next v
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
v_lds_window
);
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v
);
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
v_lds_window
);
block_sync_lds
();
}
}
while
(
++
i_total_loops
<
num_total_loop
);
if
constexpr
(
kStoreLSE
)
{
// store lse acc
auto
lse_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_acc_spans
=
decltype
(
lse_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_acc_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
/
C_LOG2E
+
log
(
l_
[
i_idx
]);
}
else
{
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
/
C_LOG2E
+
log
(
l_
[
i_idx
]);
}
#else
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
});
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEaccDramBlockWindowTmp
&
lse_acc_dram_block_window_tmp
,
// M0*1 tile
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
identity
{},
v_dram_block_window_tmp
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
randval_dram_block_window_tmp
,
lse_acc_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
num_splits
,
i_split
,
mask
,
position_encoding
,
scale_s
,
smem_ptr
,
dropout
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy
>
struct
BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert
(
Problem
::
kPadSeqLenQ
==
true
&&
Problem
::
kPadHeadDimQ
==
true
&&
Problem
::
kPadHeadDimV
==
true
);
static
constexpr
bool
kPadSeqLenQ
=
true
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kHasDropout
=
false
;
// ignore this flag
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static
constexpr
auto
R_LOG2E
=
1.0
/
log2e_v
<
SaccDataType
>
;
#endif
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kK0BlockLength
<=
32
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
&&
FmhaMask
::
IsMasking
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
2
;
else
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qr_async"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEaccElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
/*k_element_func*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEaccDramBlockWindowTmp
&
lse_acc_dram_window_tmp
,
// M0*1 tile
const
LSEaccElementFunction
&
lse_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
constexpr
auto
LdsSeq
=
Policy
::
template
GetLdsBufferSequence
<
Problem
>();
// K tile in LDS
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds_store
=
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
i_buf
)),
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto
k_lds_load
=
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>(
i_buf
)),
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
#else
auto
k_lds_Load_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>());
auto
k_lds_load
=
make_tile_window
(
k_lds_Load_view
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
#endif
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto
q
=
decltype
(
load_tile
(
q_dram_window
)){};
set_tile
(
q
,
number
<
0
>
{});
// use per-dword clear to avoid scratch
load_tile_raw
(
q
,
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
if
(
num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
auto
lse_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return
o_acc
;
}
__builtin_amdgcn_sched_barrier
(
0
);
// make sure sched_barrier(0) for this check
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
)
>
(
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// prefetch K tile
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
1
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile
(
s_acc
);
// initialize C
if
constexpr
(
k0_loops
>
1
)
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
async_load_tile_raw
(
k_lds_store
(
number
<
LdsSeq
.
at
(
number
<
i_k0
+
1
>
{})
>
{}),
k_dram_window
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load
[
number
<
LdsSeq
.
at
(
number
<
i_k0
>
{})
>
{}]);
#else
get_slice_tile
(
k_lds_load
,
sequence
<
(
LdsSeq
.
at
(
number
<
i_k0
>
{}))
*
kN0
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
i_k0
>
{})
+
1
)
*
kN0
,
kK0
>
{}));
#endif
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
auto
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
{
// tail
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load
[
number
<
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{})
>
{}]);
#else
get_slice_tile
(
k_lds_load
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{}))
*
kN0
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{})
+
1
)
*
kN0
,
kK0
>
{}));
#endif
}
__builtin_amdgcn_sched_barrier
(
1
);
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
+=
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#else
x
+=
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
*=
scale_s
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
/// TODO: only check in last iteration without increasing code size
if
constexpr
(
kHasUnevenSplits
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
seqlen_k_end_
=
seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
seqlen_k_end_
<=
col
;
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
__builtin_amdgcn_sched_barrier
(
0x7F
);
// store & prefetch next v, after the max reduction
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store the prefetch
}
if
constexpr
(
k1_loops
>
1
)
{
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
// will have scratch if move this right after load_tile(v_dram)...
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
}
__builtin_amdgcn_sched_barrier
(
0
);
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration. alibi does not have this problem
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
if
constexpr
(
kHasDropout
)
{
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>
(
randval_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
if
constexpr
(
i_k1
!=
0
&&
i_k1
<
k1_loops
-
1
)
{
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
}
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
>
{})
+
1
)
*
kN1
,
kK1
>
{}));
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store next v_buf
}
if
constexpr
(
i_k1
<
k1_loops
-
1
)
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
i_total_loops
++
;
if
(
i_total_loops
<
num_total_loop
)
{
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
if
constexpr
(
k1_loops
>=
2
&&
LdsSeq
.
at
(
number
<
0
>
{})
==
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
2
>
{}))
__builtin_amdgcn_s_barrier
();
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
}
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
1
>
{})
+
1
)
*
kN1
,
kK1
>
{}));
}
}
while
(
i_total_loops
<
num_total_loop
);
// store lse acc
if
constexpr
(
kStoreLSE
)
{
auto
lse_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_acc_spans
=
decltype
(
lse_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_acc_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
else
{
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
#else
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
});
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEaccDramBlockWindowTmp
&
lse_acc_dram_block_window_tmp
,
// M0*1 tile
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
identity
{},
v_dram_block_window_tmp
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
randval_dram_block_window_tmp
,
lse_acc_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
num_splits
,
i_split
,
mask
,
position_encoding
,
scale_s
,
smem_ptr
,
dropout
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
true
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
3
,
/* NumPrefetchV = */
3
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
0 → 100644
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
View file @
da7ef7e8
...
...
@@ -13,6 +13,7 @@ template <typename QDataType_,
typename
SaccDataType_
,
typename
SMPLComputeDataType_
,
typename
BiasDataType_
,
typename
RandValOutputDataType_
,
typename
LSEDataType_
,
typename
PDataType_
,
typename
OaccDataType_
,
...
...
@@ -29,6 +30,7 @@ struct BlockFmhaPipelineProblem
using
SaccDataType
=
remove_cvref_t
<
SaccDataType_
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
SMPLComputeDataType_
>
;
using
BiasDataType
=
remove_cvref_t
<
BiasDataType_
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
RandValOutputDataType_
>
;
using
LSEDataType
=
remove_cvref_t
<
LSEDataType_
>
;
using
PDataType
=
remove_cvref_t
<
PDataType_
>
;
using
OaccDataType
=
remove_cvref_t
<
OaccDataType_
>
;
...
...
@@ -47,8 +49,74 @@ struct BlockFmhaPipelineProblem
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Traits
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Traits
::
kHasDropout
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
template
<
typename
QDataType
,
typename
KDataType
,
typename
VDataType
,
typename
SaccDataType
,
typename
SMPLComputeDataType
,
typename
BiasDataType
,
typename
RandValOutputDataType
,
typename
LSEDataType
,
typename
PDataType
,
typename
OaccDataType
,
typename
ODataType
,
typename
BlockFmhaShape
,
bool
kIsGroupMode
,
typename
FmhaMask
,
typename
Traits
>
struct
BlockFmhaFwdSplitKVPipelineProblem
:
BlockFmhaPipelineProblem
<
QDataType
,
KDataType
,
VDataType
,
SaccDataType
,
SMPLComputeDataType
,
BiasDataType
,
RandValOutputDataType
,
LSEDataType
,
PDataType
,
OaccDataType
,
ODataType
,
BlockFmhaShape
,
kIsGroupMode
,
FmhaMask
,
Traits
>
{
static
constexpr
bool
kHasUnevenSplits
=
kIsGroupMode
||
Traits
::
kHasUnevenSplits
;
};
template
<
typename
LSEDataType_
,
typename
OaccDataType_
,
typename
ODataType_
,
index_t
HeadDimV_
,
index_t
kM0_
,
index_t
kN1_
,
bool
kIsGroupMode_
,
typename
Traits_
>
struct
BlockFmhaSplitKVCombinePipelineProblem
{
using
LSEDataType
=
remove_cvref_t
<
LSEDataType_
>
;
using
OaccDataType
=
remove_cvref_t
<
OaccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
256
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
index_t
kHeadDimV
=
HeadDimV_
;
static
constexpr
index_t
kM0
=
kM0_
;
static
constexpr
index_t
kN1
=
kN1_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
bool
kStoreLSE
=
Traits
::
kStoreLSE
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kMaxSplits
=
Traits
::
kMaxSplits
;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
...
...
@@ -22,6 +23,7 @@ struct BlockFmhaPipelineQRKSVS
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
...
...
@@ -49,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -97,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
const
char
*
name
=
"qr"
;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
...
@@ -106,6 +111,7 @@ struct BlockFmhaPipelineQRKSVS
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
...
...
@@ -125,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
...
...
@@ -133,7 +140,8 @@ struct BlockFmhaPipelineQRKSVS
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
)
const
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
@@ -240,6 +248,9 @@ struct BlockFmhaPipelineQRKSVS
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
...
...
@@ -475,6 +486,12 @@ struct BlockFmhaPipelineQRKSVS
});
});
if
constexpr
(
kHasDropout
)
{
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
smem_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
...
...
@@ -589,6 +606,7 @@ struct BlockFmhaPipelineQRKSVS
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
...
...
@@ -596,11 +614,13 @@ struct BlockFmhaPipelineQRKSVS
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
)
const
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
...
...
@@ -610,6 +630,7 @@ struct BlockFmhaPipelineQRKSVS
identity
{},
bias_dram_block_window_tmp
,
identity
{},
randval_dram_block_window_tmp
,
lse_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
@@ -618,7 +639,8 @@ struct BlockFmhaPipelineQRKSVS
mask
,
position_encoding
,
scale_s
,
smem_ptr
);
smem_ptr
,
dropout
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
...
...
@@ -23,6 +24,7 @@ struct BlockFmhaPipelineQRKSVSAsync
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
...
...
@@ -54,6 +56,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -78,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync
return
Problem
::
kBlockPerCu
;
else
{
// minimize occupancy
if
constexpr
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHasDropout
)
{
return
1
;
}
if
constexpr
(
kK0BlockLength
<=
32
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
&&
...
...
@@ -109,6 +118,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static
constexpr
const
char
*
name
=
"qr_async"
;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
...
@@ -118,6 +129,7 @@ struct BlockFmhaPipelineQRKSVSAsync
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
...
...
@@ -137,6 +149,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
...
...
@@ -145,7 +158,8 @@ struct BlockFmhaPipelineQRKSVSAsync
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
)
const
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
@@ -212,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
q_dram_window
.
init_raw
();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
...
...
@@ -285,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
k_dram_window
.
init_raw
();
constexpr
auto
k_oob_ck
=
bool_constant
<
true
>
{};
constexpr
auto
k_pre_np
=
[
&
]()
{
if
constexpr
(
kPadSeqLenK
&&
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHasDropout
)))
return
bool_constant
<
true
>
{};
else
return
bool_constant
<
false
>
{};
}();
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
...
...
@@ -292,6 +318,9 @@ struct BlockFmhaPipelineQRKSVSAsync
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
...
...
@@ -299,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// prefetch K tile
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
);
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
,
k_oob_ck
,
k_pre_np
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -322,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
async_load_tile_raw
(
k_lds_store
(
number
<
LdsSeq
.
at
(
number
<
i_k0
+
1
>
{})
>
{}),
k_dram_window
);
k_dram_window
,
k_oob_ck
,
k_pre_np
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
...
...
@@ -558,8 +589,25 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
if
constexpr
(
kHasDropout
)
{
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
randval_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
const
auto
p
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
PDataType
,
fp16_t
>
)
return
impl
::
cast_tile_pk_fp16_fp32
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
else
return
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
}();
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
...
...
@@ -609,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync
{
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
k_dram_window
.
set_window_origin
(
k_dram_block_window
.
get_window_origin
());
if
constexpr
(
k1_loops
>=
2
&&
LdsSeq
.
at
(
number
<
0
>
{})
==
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
2
>
{}))
__builtin_amdgcn_s_barrier
();
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
);
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
,
k_oob_ck
,
k_pre_np
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
}
// tail
...
...
@@ -688,6 +733,7 @@ struct BlockFmhaPipelineQRKSVSAsync
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
...
...
@@ -695,11 +741,13 @@ struct BlockFmhaPipelineQRKSVSAsync
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
)
const
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
...
...
@@ -709,6 +757,7 @@ struct BlockFmhaPipelineQRKSVSAsync
identity
{},
bias_dram_block_window_tmp
,
identity
{},
randval_dram_block_window_tmp
,
lse_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
@@ -717,7 +766,8 @@ struct BlockFmhaPipelineQRKSVSAsync
mask
,
position_encoding
,
scale_s
,
smem_ptr
);
smem_ptr
,
dropout
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -22,6 +22,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
...
...
@@ -49,6 +50,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -106,6 +108,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
...
...
@@ -113,13 +116,15 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
/*randval_dram_block_window_tmp*/
,
// not supported
LSEDramBlockWindowTmp
&
/*lse_dram_window_tmp*/
,
// not supported
FmhaMask
mask
,
PositionEncoding
/*position_encoding*/
,
float
scale_s
,
float
descale_qk
,
float
descale_sv
,
void
*
smem_ptr
)
const
void
*
smem_ptr
,
BlockDropout
&
/*dropout*/
)
const
// not supported
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -21,6 +21,7 @@ struct BlockFmhaPipelineQSKSVS
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M
16N16K32
SwizzleBTransposedCDistribution
{};
return
WarpGemmMfmaF16F16F32M
32N32K16
SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M
16N16K32
SwizzleBTransposedCDistribution
{};
return
WarpGemmMfmaBf16Bf16F32M
32N32K16
SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
@@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M
16N16K32
SwizzleBTransposedCDistribution
{};
return
WarpGemmMfmaF16F16F32M
32N32K16
SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M
16N16K32
SwizzleBTransposedCDistribution
{};
return
WarpGemmMfmaBf16Bf16F32M
32N32K16
SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
@@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
KV
()
{
// TODO: assume Q is in register
// TODO: assume K/V has same data type
...
...
@@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
single_smem_size
*
max
(
NumPrefetchK
,
NumPrefetchV
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
if
constexpr
(
AsyncCopyK
)
{
return
GetSmemSizeKV
<
Problem
>
()
+
GetSmemSizeDropout
<
Problem
>
();
}
else
{
return
ck_tile
::
max
(
GetSmemSizeKV
<
Problem
>
(),
GetSmemSizeDropout
<
Problem
>
());
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeDropout
()
{
if
constexpr
(
Problem
::
kHasDropout
)
{
constexpr
auto
gemm_0
=
QXPolicy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
config
=
decltype
(
gemm_0
)
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
WG
::
kN
;
return
(
kMPerStep
+
1
)
*
kNPerStep
*
sizeof
(
uint8_t
);
}
else
{
return
0
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKDramTileDistribution
()
{
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
da7ef7e8
...
...
@@ -43,4 +43,53 @@ struct TileFmhaShape
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
};
template
<
typename
BlockTile_
,
// sequence<...
typename
Gemm0BlockWarps_
,
typename
Gemm0WarpTile_
,
typename
Gemm1BlockWarps_
,
typename
Gemm1WarpTile_
,
typename
Gemm2BlockWarps_
,
typename
Gemm2WarpTile_
,
typename
Gemm3BlockWarps_
,
typename
Gemm3WarpTile_
,
typename
Gemm4BlockWarps_
,
typename
Gemm4WarpTile_
>
struct
TileFmhaBwdShape
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
Gemm0BlockWarps
=
remove_cvref_t
<
Gemm0BlockWarps_
>
;
using
Gemm0WarpTile
=
remove_cvref_t
<
Gemm0WarpTile_
>
;
using
Gemm1BlockWarps
=
remove_cvref_t
<
Gemm1BlockWarps_
>
;
using
Gemm1WarpTile
=
remove_cvref_t
<
Gemm1WarpTile_
>
;
using
Gemm2BlockWarps
=
remove_cvref_t
<
Gemm2BlockWarps_
>
;
using
Gemm2WarpTile
=
remove_cvref_t
<
Gemm2WarpTile_
>
;
using
Gemm3BlockWarps
=
remove_cvref_t
<
Gemm3BlockWarps_
>
;
using
Gemm3WarpTile
=
remove_cvref_t
<
Gemm3WarpTile_
>
;
using
Gemm4BlockWarps
=
remove_cvref_t
<
Gemm4BlockWarps_
>
;
using
Gemm4WarpTile
=
remove_cvref_t
<
Gemm4WarpTile_
>
;
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
Gemm0BlockWarps
{},
multiplies
{},
number
<
1
>
{});
static_assert
(
NumWarps
==
reduce_on_sequence
(
Gemm1BlockWarps
{},
multiplies
{},
number
<
1
>
{})
&&
NumWarps
==
reduce_on_sequence
(
Gemm4BlockWarps
{},
multiplies
{},
number
<
1
>
{}));
static
constexpr
index_t
kM0
=
BlockTile
::
at
(
number
<
0
>
{});
// tile size along q seqlen
static
constexpr
index_t
kN0
=
BlockTile
::
at
(
number
<
1
>
{});
// tile size along k seqlen
static
constexpr
index_t
kK0
=
BlockTile
::
at
(
number
<
2
>
{});
// tile size along gemm0(Q@K^T) unroll
static
constexpr
index_t
kK1
=
BlockTile
::
at
(
number
<
3
>
{});
// tile size along gemm1(P^T@dO) unroll
static
constexpr
index_t
kK2
=
BlockTile
::
at
(
number
<
4
>
{});
// tile size along gemm2(dO@V^T) unroll
static
constexpr
index_t
kK3
=
BlockTile
::
at
(
number
<
5
>
{});
// tile size along gemm3(dS^T@Q) unroll
static
constexpr
index_t
kK4
=
BlockTile
::
at
(
number
<
6
>
{});
// tile size along gemm4(dS@K) unroll
static
constexpr
index_t
kQKHeaddim
=
BlockTile
::
at
(
number
<
7
>
{});
// Q & K headdim, used for pipeline that need load Q/Q^T or
// K/K^T at once
static
constexpr
index_t
kVHeaddim
=
BlockTile
::
at
(
number
<
8
>
{});
// V headdim, used for pipeline
// that need load V at once
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
da7ef7e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -13,7 +13,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool
kPadHeadDimQ_
/* paddding for hdim_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
BlockAttentionBiasEnum
BiasEnum_
,
bool
kHasBiasGrad_
,
bool
kStoreLSE_
,
bool
kHasDropout_
,
bool
kDoFp8StaticQuant_
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaTraits
...
...
@@ -23,9 +25,65 @@ struct TileFmhaTraits
static
constexpr
bool
kPadHeadDimQ
=
kPadHeadDimQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
bool
kHasBiasGrad
=
kHasBiasGrad_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kHasDropout
=
kHasDropout_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ
/* padding for seqlen_q */
,
bool
kPadSeqLenK
/* padding for seqlen_k */
,
bool
kPadHeadDimQ
/* paddding for hdim_q */
,
bool
kPadHeadDimV
/* paddding for hdim_v */
,
BlockAttentionBiasEnum
BiasEnum
,
bool
kHasBiasGrad
,
bool
kStoreLSE
,
bool
kHasDropout
,
bool
kDoFp8StaticQuant
,
bool
kHasUnevenSplits_
=
true
,
index_t
kBlockPerCu
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaFwdSplitKVTraits
:
TileFmhaTraits
<
kPadSeqLenQ
,
kPadSeqLenK
,
kPadHeadDimQ
,
kPadHeadDimV
,
BiasEnum
,
kHasBiasGrad
,
kStoreLSE
,
kHasDropout
,
kDoFp8StaticQuant
,
kBlockPerCu
>
{
// determine if some split (length) is not divisible by tile size
static
constexpr
bool
kHasUnevenSplits
=
kHasUnevenSplits_
;
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
bool
kStoreLSE_
,
bool
kDoFp8StaticQuant_
,
index_t
kLogMaxSplits_
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaFwdSplitKVCombineTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
index_t
kMaxSplits
=
(
1
<<
kLogMaxSplits_
);
static_assert
(
kMaxSplits
<=
get_warp_size
()
||
kMaxSplits
%
get_warp_size
()
==
0
);
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
index_t
kBlockPerCu_
=
2
/* hint to occupancy */
>
struct
TileFmhaBwdOGradDotOTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
}
// namespace ck_tile
Prev
1
…
7
8
9
10
11
12
13
14
15
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment