Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bd689f40
Commit
bd689f40
authored
Aug 20, 2024
by
illsilin
Browse files
merge from public repo
parents
c160c6cf
a94113a9
Changes
333
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2982 additions
and
1352 deletions
+2982
-1352
include/ck_tile/core/numeric/vector_type.hpp
include/ck_tile/core/numeric/vector_type.hpp
+9
-0
include/ck_tile/core/tensor/tile_distribution.hpp
include/ck_tile/core/tensor/tile_distribution.hpp
+3
-2
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+4
-1
include/ck_tile/core/utility/philox_rand.hpp
include/ck_tile/core/utility/philox_rand.hpp
+33
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+3
-8
include/ck_tile/ops/fmha/block/block_dropout.hpp
include/ck_tile/ops/fmha/block/block_dropout.hpp
+377
-16
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+555
-332
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
...ude/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
+0
-54
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+2
-4
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+21
-18
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+12
-13
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
...e/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
+141
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp
...ude/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp
+3
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+782
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+1037
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
...k_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
.../fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
+0
-821
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
+0
-20
No files found.
include/ck_tile/core/numeric/vector_type.hpp
View file @
bd689f40
...
...
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using
int32x32_t
=
int32_t
__attribute__
((
ext_vector_type
(
32
)));
using
int32x64_t
=
int32_t
__attribute__
((
ext_vector_type
(
64
)));
// u32
// using uint32_t = ...
using
uint32x2_t
=
uint32_t
__attribute__
((
ext_vector_type
(
2
)));
using
uint32x4_t
=
uint32_t
__attribute__
((
ext_vector_type
(
4
)));
using
uint32x8_t
=
uint32_t
__attribute__
((
ext_vector_type
(
8
)));
using
uint32x16_t
=
uint32_t
__attribute__
((
ext_vector_type
(
16
)));
using
uint32x32_t
=
uint32_t
__attribute__
((
ext_vector_type
(
32
)));
using
uint32x64_t
=
uint32_t
__attribute__
((
ext_vector_type
(
64
)));
// i16
// using int16_t = ...
using
int16x2_t
=
int16_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
include/ck_tile/core/tensor/tile_distribution.hpp
View file @
bd689f40
...
...
@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
return
make_tuple
(
make_static_tile_distribution
(
tile_distribution_encoding
<
typename
Encoding
::
RsLengths
,
decltype
(
sliced_h_lengths
),
// only need to change the
// h_lengths type
remove_cvref_t
<
decltype
(
sliced_h_lengths
)
>
,
// only need to
// change the
// h_lengths type
typename
Encoding
::
Ps2RHssMajor
,
typename
Encoding
::
Ps2RHssMinor
,
typename
Encoding
::
Ys2RHsMajor
,
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
bd689f40
...
...
@@ -393,7 +393,10 @@ struct tile_window_with_static_distribution
bottom_tensor_thread_coord
,
bool_constant
<
oob_conditional_check
>
{},
pre_nop_
);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
...
...
include/ck_tile/core/utility/philox_rand.hpp
View file @
bd689f40
...
...
@@ -53,6 +53,39 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
CK_TILE_HOST_DEVICE
void
get_random_8x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
,
const
index_t
start_idx
)
const
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32x4_t
tmp
;
tmp
[
0
]
=
tmp_ph
.
x
;
tmp
[
1
]
=
tmp_ph
.
y
;
tmp
[
2
]
=
tmp_ph
.
z
;
tmp
[
3
]
=
tmp_ph
.
w
;
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp
[
start_idx
];
out_tmp
[
1
]
=
tmp
[
start_idx
+
2
];
}
CK_TILE_HOST_DEVICE
void
get_random_4x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
,
const
index_t
start_idx
)
const
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32x4_t
tmp
;
tmp
[
0
]
=
tmp_ph
.
x
;
tmp
[
1
]
=
tmp_ph
.
y
;
tmp
[
2
]
=
tmp_ph
.
z
;
tmp
[
3
]
=
tmp_ph
.
w
;
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp
[
start_idx
];
}
private:
struct
ull2
{
...
...
include/ck_tile/ops/fmha.hpp
View file @
bd689f40
...
...
@@ -8,21 +8,16 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
...
...
include/ck_tile/ops/fmha/block/block_dropout.hpp
View file @
bd689f40
...
...
@@ -286,11 +286,226 @@ struct BlockDropout
});
}
ck_tile
::
philox
ph
;
const
float
rp_undrop
;
const
uint8_t
p_undrop_in_uint8_t
;
const
bool
is_store_randval
;
};
template
<
bool
IsDropout_
,
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropoutBwd
;
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropoutBwd
<
false
,
IsWG32_
,
IsStoreRandval_
>
{
static
constexpr
bool
IsDropout
=
false
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
,
typename
RandValDramBlockWindowTmp
>
__host__
__device__
static
constexpr
auto
MakeRandvalDramWindow
(
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
index_t
seqlen_qk_start
)
{
(
void
)
randval_dram_block_window_tmp
;
(
void
)
seqlen_qk_start
;
return
make_null_tile_window
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{}));
}
};
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropoutBwd
<
true
,
IsWG32_
,
IsStoreRandval_
>
{
static
constexpr
bool
IsDropout
=
true
;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static
constexpr
bool
IsWG32
=
IsWG32_
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
CK_TILE_HOST_DEVICE
BlockDropoutBwd
(
index_t
i_batch
,
index_t
i_head
,
index_t
nheads
,
unsigned
long
long
seed
,
unsigned
long
long
offset
,
float
rp_undrop_
,
uint8_t
p_undrop_in_uint8_t_
)
:
ph
(
seed
,
offset
+
(
i_batch
*
nheads
+
i_head
)
*
get_warp_size
()
+
(
IsWG32
?
get_lane_id
()
:
((
get_lane_id
()
&
47
)
+
((
get_warp_id
()
&
1
)
<<
4
)))),
rp_undrop
(
rp_undrop_
),
p_undrop_in_uint8_t
(
p_undrop_in_uint8_t_
)
{
}
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
,
typename
RandValDramBlockWindowTmp
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandvalDramWindow
(
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
index_t
seqlen_qk_start
)
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsFwd
)
&&
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
index_t
kMPerStep
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
MWarp
*
WG
::
kM
*
2
;
}
else
{
return
MWarp
*
WG
::
kM
;
}
}();
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
const
auto
block_origin
=
randval_dram_block_window_tmp
.
get_window_origin
();
auto
randval_dram_window
=
[
&
]()
{
if
constexpr
(
IsFwd
)
{
return
make_tile_window
(
randval_dram_block_window_tmp
.
get_bottom_tensor_view
(),
ck_tile
::
make_tuple
(
number
<
kMPerStep
>
{},
number
<
kNPerStep
>
{}),
{
block_origin
.
at
(
number
<
0
>
{}),
seqlen_qk_start
});
// M/N
}
else
{
return
make_tile_window
(
randval_dram_block_window_tmp
.
get_bottom_tensor_view
(),
ck_tile
::
make_tuple
(
number
<
kMPerStep
>
{},
number
<
kNPerStep
>
{}),
{
seqlen_qk_start
,
block_origin
.
at
(
number
<
1
>
{})});
// M/N
}
}();
return
randval_dram_window
;
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValLdsBlockDescriptor
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
WG
::
kN
;
constexpr
index_t
kN1
=
8
;
constexpr
index_t
kN0
=
kNPerStep
/
kN1
;
constexpr
auto
randval_lds_block_desc_0
=
make_naive_tensor_descriptor
(
ck_tile
::
make_tuple
(
number
<
kN0
>
{},
number
<
kMPerStep
>
{},
number
<
kN1
>
{}),
ck_tile
::
make_tuple
(
number
<
(
kMPerStep
+
1
)
*
kN1
>
{},
number
<
kN1
>
{},
number
<
1
>
{}),
number
<
kN1
>
{},
number
<
1
>
{});
constexpr
auto
randval_lds_block_desc
=
transform_tensor_descriptor
(
randval_lds_block_desc_0
,
ck_tile
::
make_tuple
(
make_pass_through_transform
(
number
<
kMPerStep
>
{}),
make_merge_transform
(
ck_tile
::
make_tuple
(
number
<
kN0
>
{},
number
<
kN1
>
{}))),
ck_tile
::
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
ck_tile
::
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
randval_lds_block_desc
;
}
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsFwd
)
&&
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
index_t
MIterPerWarp
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
2
;
}
else
{
return
1
;
}
}();
constexpr
index_t
NIterPerWarp
=
1
;
constexpr
auto
randval_block_outer_part_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
// except headdim256.
constexpr
auto
randval_block_inner_part_dstr_encoding
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
BlockGemm
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
BlockGemm
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
BlockGemm
::
CDataType
,
float
>
)
{
if
constexpr
(
IsWG32
)
return
typename
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
::
CWarpDstrEncoding
{};
else
return
typename
WarpGemmMfmaF16F16F32M16N16K16
::
CWarpDstrEncoding
{};
}
else
{
if
constexpr
(
IsWG32
)
return
typename
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
::
CWarpDstrEncoding
{};
else
return
typename
WarpGemmMfmaBf16Bf16F32M16N16K16
::
CWarpDstrEncoding
{};
}
}();
constexpr
auto
randval_block_part_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
randval_block_outer_part_dstr_encoding
,
randval_block_inner_part_dstr_encoding
);
return
make_static_tile_distribution
(
randval_block_part_dstr_encode
);
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValLdsShuffleTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
1
;
constexpr
index_t
NIterPerWarp
=
1
;
constexpr
auto
randval_block_outer_part_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
randval_block_part_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
randval_block_outer_part_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
return
make_static_tile_distribution
(
randval_block_part_dstr_encode
);
}
template
<
typename
BlockGemm
,
typename
PComputeDataType
,
typename
RandValOutputDataType
,
typename
PComputeWindow
,
typename
RandValDramWindow
>
CK_TILE_HOST_DEVICE
void
Run
(
const
index_t
start_m0_idx
,
CK_TILE_HOST_DEVICE
void
Run
(
void
*
randval_ptr
,
const
index_t
start_m0_idx
,
const
index_t
start_n0_idx
,
PComputeWindow
&
p_compute
,
RandValDramWindow
&
randval_dram_window
)
const
{
...
...
@@ -305,30 +520,177 @@ struct BlockDropout
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
// randval tile in LDS
auto
randval_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
uint8_t
*>
(
randval_ptr
),
MakeRandValLdsBlockDescriptor
<
BlockGemm
>
());
auto
randval_lds_window
=
make_tile_window
(
randval_lds
,
MakeRandValLdsBlockDescriptor
<
BlockGemm
>
().
get_lengths
(),
{
0
,
0
});
// register distribute
auto
randval
=
auto
randval
_dist_generated
=
make_static_distributed_tensor
<
uint8_t
>
(
MakeRandValTileDistribution
<
BlockGemm
>
());
static_assert
(
randval
.
kThreadElementSpaceSize
==
16
);
static_assert
(
randval
_dist_generated
.
kThreadElementSpaceSize
==
16
);
const
int
start_n0_idx
=
randval_dram_window
.
get_window_origin
().
at
(
number
<
1
>
{});
static_for
<
0
,
kNPerBlock
/
kNPerStep
,
1
>
{}([
&
](
auto
i_n0
)
{
static_for
<
0
,
kMPerBlock
/
kMPerStep
,
1
>
{}([
&
](
auto
i_m0
)
{
int
block_row_start
=
(
start_m0_idx
/
WG
::
kM
)
+
i_m0
;
int
block_col_start
=
(
start_n0_idx
/
WG
::
kN
)
+
(
i_n0
*
NWarp
)
+
get_warp_id
();
auto
randval_lds_read_window
=
make_tile_window
(
randval_lds_window
.
get_bottom_tensor_view
(),
randval_lds_window
.
get_window_lengths
(),
randval_lds_window
.
get_window_origin
(),
MakeRandValLdsShuffleTileDistribution
<
BlockGemm
>
());
static_for
<
0
,
kMPerBlock
/
kMPerStep
,
1
>
{}([
&
](
auto
i_m0
)
{
static_for
<
0
,
kNPerBlock
/
kNPerStep
,
1
>
{}([
&
](
auto
i_n0
)
{
int
block_row_start
=
(
start_m0_idx
/
WG
::
kM
)
+
(
i_m0
*
MWarp
)
+
get_warp_id
();
int
block_col_start
=
(
start_n0_idx
/
WG
::
kN
)
+
i_n0
;
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
// generate random number
uint8_t
random_uint8_t
[
16
];
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
constexpr
auto
randval_dist_generated_spans
=
decltype
(
randval_dist_generated
)
::
get_distributed_spans
();
int
i_random_idx
=
0
;
sweep_tile_span
(
randval_dist_generated_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_dist_generated_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval_dist_generated
(
i_j_idx
)
=
random_uint8_t
[
i_random_idx
++
];
});
});
// save to LDS
store_tile
(
randval_lds_window
,
randval_dist_generated
);
block_sync_lds
();
// read from LDS to register
auto
randval
=
load_tile
(
randval_lds_read_window
);
constexpr
auto
randval_spans
=
decltype
(
randval
)
::
get_distributed_spans
();
int
i_random_idx
=
0
;
sweep_tile_span
(
randval_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
>
{};
constexpr
auto
p_idx1
=
tile_distributed_index
<
i_n0
,
idx1
.
impl_
.
at
(
1
),
idx1
.
impl_
.
at
(
2
)
>
{};
constexpr
auto
p_idx
=
ck_tile
::
make_tuple
(
p_idx0
,
p_idx1
);
constexpr
auto
r_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval
(
r_idx
)
=
random_uint8_t
[
i_random_idx
++
];
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
,
idx0
.
impl_
.
at
(
1
),
idx0
.
impl_
.
at
(
2
)
>
{};
p_compute
(
p_idx
)
=
randval
[
r_idx
]
<=
p_undrop_in_uint8_t
?
p_compute
[
p_idx
]
*
rp_undrop
:
PComputeDataType
(
0
);
});
});
// save to Global
if
constexpr
(
IsStoreRandval
)
{
const
auto
randval_store
=
cast_tile
<
RandValOutputDataType
>
(
randval
);
store_tile
(
randval_dram_window
,
randval_store
);
move_tile_window
(
randval_dram_window
,
{
0
,
kNPerStep
});
}
});
if
constexpr
(
IsStoreRandval
)
{
move_tile_window
(
randval_dram_window
,
{
kMPerStep
,
-
kNPerBlock
});
}
});
if
constexpr
(
IsStoreRandval
)
{
move_tile_window
(
randval_dram_window
,
{
-
kMPerBlock
,
kNPerBlock
});
}
}
template
<
typename
BlockGemm
,
typename
RandValOutputDataType
,
typename
PComputeWindow
,
typename
RandValDramWindow
>
CK_TILE_HOST_DEVICE
void
Run
(
const
index_t
start_m0_idx
,
const
index_t
start_n0_idx
,
PComputeWindow
&
p_compute
,
RandValDramWindow
&
randval_dram_window
)
const
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
bool
MBwdWG16SingleIterCheck
=
(
!
IsWG32
)
&&
(
kMPerBlock
==
16
);
constexpr
index_t
kMPerStep
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
MWarp
*
WG
::
kM
*
2
;
}
else
{
return
MWarp
*
WG
::
kM
;
}
}();
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
// register distribute
auto
randval
=
make_static_distributed_tensor
<
uint8_t
>
(
MakeRandValTileDistribution
<
BlockGemm
,
false
>
());
if
constexpr
(
IsWG32
)
static_assert
(
randval
.
kThreadElementSpaceSize
==
16
);
else
static_assert
(
randval
.
kThreadElementSpaceSize
==
4
||
randval
.
kThreadElementSpaceSize
==
8
);
static_for
<
0
,
kNPerBlock
/
kNPerStep
,
1
>
{}([
&
](
auto
i_n0
)
{
static_for
<
0
,
kMPerBlock
/
kMPerStep
,
1
>
{}([
&
](
auto
i_m0
)
{
int
block_row_start
,
block_col_start
;
if
constexpr
(
IsWG32
)
{
block_row_start
=
(
start_m0_idx
/
WG
::
kM
)
+
i_m0
;
block_col_start
=
(
start_n0_idx
/
WG
::
kN
)
+
(
i_n0
*
NWarp
)
+
get_warp_id
();
}
else
{
block_row_start
=
start_m0_idx
/
32
+
i_m0
;
block_col_start
=
(
start_n0_idx
/
32
)
+
get_warp_id
()
/
2
+
i_n0
*
2
;
}
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
// generate random number
uint8_t
*
random_uint8_t_
;
if
constexpr
(
MBwdWG16SingleIterCheck
)
{
uint8_t
random_uint8_t
[
4
];
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3
const
index_t
start_idx
=
((
get_lane_id
()
>>
4
)
&
1
)
+
(((
start_m0_idx
>>
4
)
&
1
)
<<
1
);
ph
.
get_random_4x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
start_idx
);
random_uint8_t_
=
random_uint8_t
;
}
else
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
uint8_t
random_uint8_t
[
8
];
// t0 ~t15/t32~t47: 0
// t16~t31/t48~t63: 1
const
index_t
start_idx
=
(
get_lane_id
()
>>
4
)
&
1
;
ph
.
get_random_8x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
start_idx
);
random_uint8_t_
=
random_uint8_t
;
}
else
{
uint8_t
random_uint8_t
[
16
];
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
random_uint8_t_
=
random_uint8_t
;
}
constexpr
auto
randval_spans
=
decltype
(
randval
)
::
get_distributed_spans
();
int
i_random_idx
=
0
;
sweep_tile_span
(
randval_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
r_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval
(
r_idx
)
=
random_uint8_t_
[
i_random_idx
++
];
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
+
idx0
.
impl_
.
at
(
0
),
idx0
.
impl_
.
at
(
1
),
idx0
.
impl_
.
at
(
2
)
>
{};
constexpr
auto
p_idx1
=
tile_distributed_index
<
i_n0
>
{};
constexpr
auto
p_idx
=
ck_tile
::
make_tuple
(
p_idx0
,
p_idx1
);
p_compute
(
p_idx
)
=
randval
[
r_idx
]
<=
p_undrop_in_uint8_t
...
...
@@ -337,19 +699,19 @@ struct BlockDropout
});
});
// save to Global
if
(
is_s
tore
_r
andval
)
if
constexpr
(
IsS
tore
R
andval
)
{
const
auto
randval_store
=
cast_tile
<
RandValOutputDataType
>
(
randval
);
store_tile
(
randval_dram_window
,
randval_store
);
move_tile_window
(
randval_dram_window
,
{
kMPerStep
,
0
});
}
});
if
(
is_s
tore
_r
andval
)
if
constexpr
(
IsS
tore
R
andval
)
{
move_tile_window
(
randval_dram_window
,
{
-
kMPerBlock
,
kNPerStep
});
}
});
if
(
is_s
tore
_r
andval
)
if
constexpr
(
IsS
tore
R
andval
)
{
move_tile_window
(
randval_dram_window
,
{
kMPerBlock
,
-
kNPerBlock
});
}
...
...
@@ -358,7 +720,6 @@ struct BlockDropout
ck_tile
::
philox
ph
;
const
float
rp_undrop
;
const
uint8_t
p_undrop_in_uint8_t
;
const
bool
is_store_randval
;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
bd689f40
...
...
@@ -23,13 +23,9 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
KGradEpiloguePipeline_
,
typename
VGradEpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
KGradEpiloguePipeline_
,
typename
VGradEpiloguePipeline_
>
struct
FmhaBwdDQDKDVKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
KGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
KGradEpiloguePipeline_
>
;
using
VGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
VGradEpiloguePipeline_
>
;
...
...
@@ -59,9 +55,12 @@ struct FmhaBwdDQDKDVKernel
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
FmhaPipeline
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
FmhaPipeline
::
kHasDropout
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
using
FmhaDropout
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaDropout
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
static
constexpr
bool
kHasDropout
=
FmhaDropout
::
IsDropout
;
static
constexpr
bool
kIsStoreRandval
=
FmhaDropout
::
IsStoreRandval
;
static
constexpr
bool
kIsDeterministic
=
FmhaPipeline
::
kIsDeterministic
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
...
...
@@ -73,9 +72,12 @@ struct FmhaBwdDQDKDVKernel
{
// sync with generate.py
// clang-format off
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr0
=
typename
bfs
::
Gemm0BlockWarps
;
using
gbr1
=
typename
bfs
::
Gemm1BlockWarps
;
using
gbr4
=
typename
bfs
::
Gemm4BlockWarps
;
using
gwt0
=
typename
bfs
::
Gemm0WarpTile
;
using
gwt1
=
typename
bfs
::
Gemm1WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
...
...
@@ -88,13 +90,17 @@ struct FmhaBwdDQDKDVKernel
return
_SS_
(
"fmha_bwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK2
)
+
"x"
+
_TS_
(
bfs
::
kK3
)
+
"x"
+
_TS_
(
bfs
::
kK4
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
)
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
);
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kIsStoreRandval
?
"_storerandval"
:
""
)
+
(
kIsDeterministic
?
"_deterministic"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
...
...
@@ -117,7 +123,7 @@ struct FmhaBwdDQDKDVKernel
const
void
*
lse_ptr
;
const
void
*
do_ptr
;
const
void
*
d_ptr
;
void
*
dq_ptr
;
void
*
dq_
acc_
ptr
;
void
*
dk_ptr
;
void
*
dv_ptr
;
...
...
@@ -131,14 +137,13 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
num_head_q
;
ck_tile
::
index_t
nhead_ratio_qk
;
float
raw_scale
;
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
;
#endif
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dv
;
...
...
@@ -147,8 +152,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dk
;
ck_tile
::
index_t
nhead_stride_dv
;
};
struct
FmhaBwdCommonBiasKargs
...
...
@@ -206,7 +212,6 @@ struct FmhaBwdDQDKDVKernel
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
...
...
@@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
};
struct
FmhaBwdDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
};
struct
FmhaBwdBatchModeKargs
:
FmhaBwdCommonKargs
,
...
...
@@ -228,12 +237,15 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasBiasGrad
,
FmhaBwdBatchModeBiasGradKargs
,
FmhaBwdEmptyKargs
<
1
>>
,
std
::
conditional_t
<
kHasMask
,
FmhaBwdMaskKargs
,
FmhaBwdEmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
FmhaBwdBatchModeDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
std
::
conditional_t
<
kHasDropout
,
FmhaBwdBatchModeDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdDeterministicKargs
,
FmhaBwdEmptyKargs
<
4
>>
{
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
batch_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dv
;
};
...
...
@@ -247,7 +259,8 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasBiasGrad
,
FmhaBwdCommonBiasGradKargs
,
FmhaBwdEmptyKargs
<
1
>>
,
std
::
conditional_t
<
kHasMask
,
FmhaBwdMaskKargs
,
FmhaBwdEmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
FmhaBwdCommonDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
std
::
conditional_t
<
kHasDropout
,
FmhaBwdCommonDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdDeterministicKargs
,
FmhaBwdEmptyKargs
<
4
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
...
...
@@ -266,10 +279,10 @@ struct FmhaBwdDQDKDVKernel
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dq_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
...
...
@@ -283,6 +296,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
...
...
@@ -293,6 +307,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
...
...
@@ -301,14 +318,15 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_do
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dk
,
ck_tile
::
index_t
batch_stride_dv
,
ck_tile
::
index_t
batch_stride_dbias
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
...
...
@@ -317,7 +335,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr
,
do_ptr
,
d_ptr
,
dq_ptr
,
dq_
acc_
ptr
,
dk_ptr
,
dv_ptr
,
seqlen_q
,
...
...
@@ -327,13 +345,12 @@ struct FmhaBwdDQDKDVKernel
num_head_q
,
nhead_ratio_qk
,
scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast
<
float
>
(
scale
*
ck_tile
::
log2e_v
<>
),
#endif
stride_q
,
stride_k
,
stride_v
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
nhead_stride_q
,
...
...
@@ -341,15 +358,20 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_lsed
,
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
{},
// placeholder for deterministic
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_do
,
batch_stride_lsed
,
batch_stride_dq_acc
,
batch_stride_dk
,
batch_stride_dv
};
...
...
@@ -384,11 +406,18 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
}
}
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
...
...
@@ -404,10 +433,10 @@ struct FmhaBwdDQDKDVKernel
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dq_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
...
...
@@ -422,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
...
...
@@ -432,13 +462,15 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch
_stride_
lsed
,
ck_tile
::
index_t
split
_stride_
dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
...
...
@@ -447,7 +479,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr
,
do_ptr
,
d_ptr
,
dq_ptr
,
dq_
acc_
ptr
,
dk_ptr
,
dv_ptr
,
-
1
,
// seqlen will be updated by another pointer
...
...
@@ -457,13 +489,12 @@ struct FmhaBwdDQDKDVKernel
num_head_q
,
nhead_ratio_qk
,
scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast
<
float
>
(
scale
*
ck_tile
::
log2e_v
<>
),
#endif
stride_q
,
stride_k
,
stride_v
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
nhead_stride_q
,
...
...
@@ -471,11 +502,14 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_lsed
,
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
{},
// placeholder for deterministic
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
)};
...
...
@@ -506,10 +540,16 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
}
}
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
...
...
@@ -518,7 +558,17 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_k_
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_k_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_k_
,
FmhaPipeline
::
kN0
),
nhead_
,
batch_size_
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
()
{
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_block
,
i_nhead
,
i_batch
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -536,7 +586,7 @@ struct FmhaBwdDQDKDVKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_n
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
kargs
.
seqlen_k
);
const
auto
[
i_tile_n
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
);
const
index_t
i_n0
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN0
);
...
...
@@ -547,6 +597,7 @@ struct FmhaBwdDQDKDVKernel
long_index_t
batch_offset_randval
=
0
;
long_index_t
batch_offset_do
=
0
;
long_index_t
batch_offset_lsed
=
0
;
long_index_t
batch_offset_dq_acc
=
0
;
long_index_t
batch_offset_dk
=
0
;
long_index_t
batch_offset_dv
=
0
;
long_index_t
batch_offset_dbias
=
0
;
...
...
@@ -557,13 +608,14 @@ struct FmhaBwdDQDKDVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_lsed
=
query_start
;
batch_offset_dq_acc
=
query_start
*
kargs
.
stride_dq_acc
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
;
...
...
@@ -576,7 +628,7 @@ struct FmhaBwdDQDKDVKernel
{
batch_offset_dbias
=
key_start
;
}
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
}
...
...
@@ -603,13 +655,14 @@ struct FmhaBwdDQDKDVKernel
}
else
{
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_do
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dk
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dk
;
batch_offset_dv
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dv
;
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_do
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dq_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq_acc
;
batch_offset_dk
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dk
;
batch_offset_dv
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dv
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
...
...
@@ -618,7 +671,7 @@ struct FmhaBwdDQDKDVKernel
{
batch_offset_dbias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dbias
;
}
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
batch_offset_randval
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
...
...
@@ -646,14 +699,11 @@ struct FmhaBwdDQDKDVKernel
const
OGradDataType
*
do_ptr
=
reinterpret_cast
<
const
OGradDataType
*>
(
kargs
.
do_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_do
+
batch_offset_do
;
QGradDataType
*
dq_ptr
=
reinterpret_cast
<
QGradDataType
*>
(
kargs
.
dq_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_q
+
batch_offset_q
;
KGradDataType
*
dk_ptr
=
reinterpret_cast
<
KGradDataType
*>
(
kargs
.
dk_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_k
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_
d
k
+
batch_offset_dk
;
VGradDataType
*
dv_ptr
=
reinterpret_cast
<
VGradDataType
*>
(
kargs
.
dv_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_v
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_
d
v
+
batch_offset_dv
;
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
...
...
@@ -663,45 +713,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQ
>
{},
number
<
1
>
{});
const
auto
q_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
}();
const
auto
qt_dram_naive
=
transform_tensor_view
(
q_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_q
),
make_pass_through_transform
(
kargs
.
seqlen_q
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
qt_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQTLoadOnce
)
{
return
pad_tensor_view
(
qt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenQ
>
{});
}
else
{
return
pad_tensor_view
(
qt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK3
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenQ
>
{});
}
}();
const
auto
q_dram
=
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
const
auto
k_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
k_ptr
,
...
...
@@ -709,45 +724,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_k
,
1
),
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
const
auto
k_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}
}();
const
auto
kt_dram_naive
=
transform_tensor_view
(
k_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_q
),
make_pass_through_transform
(
kargs
.
seqlen_k
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
kt_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKTLoadOnce
)
{
return
pad_tensor_view
(
kt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenK
>
{});
}
else
{
return
pad_tensor_view
(
kt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK4
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenK
>
{});
}
}();
const
auto
k_dram
=
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
...
...
@@ -756,20 +736,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
if
constexpr
(
FmhaPipeline
::
kVLoadOnce
)
{
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimV
>
{});
}
else
{
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK2
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimV
>
{});
}
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimV
>
{});
}();
const
auto
lse_dram
=
[
&
]()
{
...
...
@@ -792,145 +762,89 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_do
,
1
),
number
<
FmhaPipeline
::
kAlignmentOGrad
>
{},
number
<
1
>
{});
const
auto
do_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradLoadOnce
)
{
return
pad_tensor_view
(
do_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimV
>
{});
}
else
{
return
pad_tensor_view
(
do_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK2
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimV
>
{});
}
}();
const
auto
dot_dram_naive
=
transform_tensor_view
(
do_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_v
),
make_pass_through_transform
(
kargs
.
seqlen_q
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
dot_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradTLoadOnce
)
{
return
pad_tensor_view
(
dot_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenQ
>
{});
}
else
{
return
pad_tensor_view
(
dot_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenQ
>
{});
}
}();
auto
dq_dram
=
[
&
]()
{
const
auto
dq_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
dq_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
const
auto
do_dram
=
pad_tensor_view
(
do_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimV
>
{});
auto
q_dram_window
=
make_tile_window
(
q_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
auto
qt_dram_window
=
make_tile_window
(
qt_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK3
>
{});
}(),
{
0
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
i_n0
,
0
});
auto
kt_dram_window
=
make_tile_window
(
kt_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK4
>
{});
}(),
{
0
,
i_n0
});
auto
v_dram_window
=
make_tile_window
(
v_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kVLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK2
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
{
i_n0
,
0
});
auto
do_dram_window
=
make_tile_window
(
do_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK2
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
{
0
,
0
});
auto
dot_dram_window
=
make_tile_window
(
dot_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kK1
>
{});
}(),
{
0
,
0
});
auto
dq_dram_window
=
make_tile_window
(
dq_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
auto
dq_dram_window
=
[
&
,
i_tile_n_
=
i_tile_n
,
i_nhead_
=
i_nhead
]()
{
if
constexpr
(
kIsDeterministic
)
{
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_dq_acc
+
static_cast
<
long_index_t
>
(
i_tile_n_
)
*
kargs
.
split_stride_dq_acc
+
batch_offset_dq_acc
;
auto
dq_acc_dram
=
[
&
]()
{
const
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
dq_acc_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
}
else
{
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_dq_acc
+
batch_offset_dq_acc
;
auto
dq_acc_dram
=
[
&
]()
{
const
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
dq_acc_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
}
}();
auto
lse_dram_window
=
make_tile_window
(
lse_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{}),
{
0
});
...
...
@@ -1008,9 +922,7 @@ struct FmhaBwdDQDKDVKernel
// TODO: how to use s_read?
AccDataType
slope
=
*
(
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
alibi_slope_ptr
)
+
i_batch_
*
kargs
.
alibi_slope_stride
+
i_nhead_
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope
*=
ck_tile
::
log2e_v
<>
;
#endif
if
constexpr
(
kHasMask
)
{
return
make_alibi_from_lr_mask
<
AccDataType
,
false
>
(
slope
,
...
...
@@ -1033,35 +945,34 @@ struct FmhaBwdDQDKDVKernel
}();
// dropout
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
0
;
uint64_t
drop_offset
=
0
;
bool
is_store_randval
=
false
;
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
if
constexpr
(
kHasDropout
)
{
rp_undrop
=
kargs
.
rp_undrop
;
scale_rp_undrop
=
kargs
.
scale_rp_undrop
;
p_undrop_in_uint8_t
=
kargs
.
p_undrop_in_uint8_t
;
drop_seed
=
kargs
.
drop_seed
;
drop_offset
=
kargs
.
drop_offset
;
is_store_randval
=
kargs
.
is_store_randval
;
rp_undrop
=
kargs
.
rp_undrop
;
scale_rp_undrop
=
kargs
.
scale_rp_undrop
;
}
BlockDropout
dropout
(
i_batch
,
i_nhead
,
kargs
.
num_head_q
,
drop_seed
,
drop_offset
,
rp_undrop
,
p_undrop_in_uint8_t
,
is_store_randval
);
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
if
constexpr
(
kHasDropout
)
{
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
}
else
{
return
FmhaDropout
{};
};
}();
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
RandValOutputDataType
*
rand_val_ptr
=
reinterpret_cast
<
RandValOutputDataType
*>
(
kargs
.
rand_val_ptr
)
+
...
...
@@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel
}();
auto
[
dk_acc_tile
,
dv_acc_tile
]
=
FmhaPipeline
{}(
q_dram_window
,
qt_dram_window
,
k_dram_window
,
kt_dram_window
,
v_dram_window
,
bias_dram_window
,
randval_dram_window
,
do_dram_window
,
dot_dram_window
,
lse_dram_window
,
d_dram_window
,
dq_dram_window
,
...
...
@@ -1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel
mask
,
position_encoding
,
kargs
.
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
kargs
.
scale
,
#endif
rp_undrop
,
scale_rp_undrop
,
smem_ptr
,
...
...
@@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel
}
};
template
<
typename
TilePartitioner_
,
typename
FmhaBwdOGradDotO_
>
template
<
typename
FmhaBwdOGradDotO_
>
struct
FmhaBwdOGradDotOKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaBwdOGradDotO
=
ck_tile
::
remove_cvref_t
<
FmhaBwdOGradDotO_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdOGradDotO
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdOGradDotO
::
kBlockPerCu
;
...
...
@@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
nhead_stride_d
;
ck_tile
::
index_t
batch_stride_d
;
};
struct
FmhaBwdOGradDotOBatchModeKargs
:
FmhaBwdOGradDotOCommonKargs
{
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_d
;
};
struct
FmhaBwdOGradDotOGroupModeKargs
:
FmhaBwdOGradDotOCommonKargs
...
...
@@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
stride_o
,
nhead_stride_do
,
nhead_stride_o
,
nhead_stride_d
,
batch_stride_d
},
nhead_stride_d
},
batch_stride_do
,
batch_stride_o
};
batch_stride_o
,
batch_stride_d
};
return
kargs
;
}
...
...
@@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
nhead_stride_d
,
ck_tile
::
index_t
batch_stride_d
)
ck_tile
::
index_t
nhead_stride_d
)
{
Kargs
kargs
{{
o_ptr
,
do_ptr
,
...
...
@@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
stride_o
,
nhead_stride_do
,
nhead_stride_o
,
nhead_stride_d
,
batch_stride_d
},
nhead_stride_d
},
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
)};
return
kargs
;
...
...
@@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_q_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
),
nhead_
,
batch_size_
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
()
{
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_block
,
i_nhead
,
i_batch
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// divide problem
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
kargs
.
seqlen_q
);
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
kM0
);
...
...
@@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
batch_offset_o
=
query_start
*
kargs
.
stride_o
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_d
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_d
;
batch_offset_d
=
query_start
;
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
...
...
@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel
}
};
template
<
typename
FmhaBwdConvertQGrad_
>
struct
FmhaBwdConvertQGradKernel
{
using
FmhaBwdConvertQGrad
=
ck_tile
::
remove_cvref_t
<
FmhaBwdConvertQGrad_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdConvertQGrad
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdConvertQGrad
::
kBlockPerCu
;
static
constexpr
ck_tile
::
index_t
kM0
=
FmhaBwdConvertQGrad
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
FmhaBwdConvertQGrad
::
kN0
;
static
constexpr
ck_tile
::
index_t
kQKHeaddim
=
FmhaBwdConvertQGrad
::
kQKHeaddim
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaBwdConvertQGrad
::
AccDataType
>
;
using
QGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaBwdConvertQGrad
::
QGradDataType
>
;
static
constexpr
bool
kIsGroupMode
=
FmhaBwdConvertQGrad
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
FmhaBwdConvertQGrad
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaBwdConvertQGrad
::
kPadHeadDimQ
;
static
constexpr
bool
kIsDeterministic
=
FmhaBwdConvertQGrad
::
kIsDeterministic
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadSeqLenQ
)
n
+=
"s"
;
if
(
kPadHeadDimQ
)
n
+=
"d"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_bwd_convert_dq_d"
)
+
_TS_
(
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QGradDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
(
kIsDeterministic
?
"_deterministic"
:
""
)
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
))
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
);
#undef _SS_
#undef _TS_
// clang-format on
}
// to avoid duplicated base class prblem, introduce an template arg
template
<
ck_tile
::
index_t
I
>
struct
FmhaBwdConvertQGradEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct
FmhaBwdConvertQGradCommonKargs
{
const
void
*
dq_acc_ptr
;
void
*
dq_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
stride_dq
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dq
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
};
struct
FmhaBwdConvertQGradDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
};
struct
FmhaBwdConvertQGradBatchModeKargs
:
FmhaBwdConvertQGradCommonKargs
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdConvertQGradDeterministicKargs
,
FmhaBwdConvertQGradEmptyKargs
<
0
>>
{
ck_tile
::
index_t
batch_stride_dq
;
ck_tile
::
index_t
batch_stride_dq_acc
;
};
struct
FmhaBwdConvertQGradGroupModeKargs
:
FmhaBwdConvertQGradCommonKargs
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdConvertQGradDeterministicKargs
,
FmhaBwdConvertQGradEmptyKargs
<
0
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
FmhaBwdConvertQGradGroupModeKargs
,
FmhaBwdConvertQGradBatchModeKargs
>
;
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
dq_acc_ptr
,
void
*
dq_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dq
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dq
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
split_stride_dq_acc
)
{
Kargs
kargs
{{
dq_acc_ptr
,
dq_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
stride_dq
,
stride_dq_acc
,
nhead_stride_dq
,
nhead_stride_dq_acc
},
{},
batch_stride_dq
,
batch_stride_dq_acc
};
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
}
/// @brief Build group-mode (varlen) kargs (enabled only when kIsGroupMode).
/// Sequence lengths are not known at launch; seqlen_q/seqlen_k are set to -1
/// and recomputed on-device from the seqstart pointers.
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* dq_acc_ptr,
          void* dq_ptr,
          const void* seqstart_q_ptr,
          const void* seqstart_k_ptr,
          ck_tile::index_t hdim_q,
          ck_tile::index_t stride_dq,
          ck_tile::index_t stride_dq_acc,
          ck_tile::index_t nhead_stride_dq,
          ck_tile::index_t nhead_stride_dq_acc,
          ck_tile::index_t split_stride_dq_acc)
{
    Kargs kargs{{dq_acc_ptr,
                 dq_ptr,
                 -1, // seqlen will be updated by another pointer
                 -1, //
                 hdim_q,
                 stride_dq,
                 stride_dq_acc,
                 nhead_stride_dq,
                 nhead_stride_dq_acc},
                {},
                reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                reinterpret_cast<const int32_t*>(seqstart_k_ptr)};

    // The deterministic base only exists for kIsDeterministic builds.
    if constexpr(kIsDeterministic)
    {
        kargs.split_stride_dq_acc = split_stride_dq_acc;
    }

    return kargs;
}
/// @brief Launch grid: one block per kM0-sized chunk of the query sequence
///        on x, heads on y, batches on z.
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
    const ck_tile::index_t num_tile_m = ck_tile::integer_divide_ceil(seqlen_q_, kM0);
    return dim3(num_tile_m, nhead_, batch_size_);
}
/// @brief Inverse of GridSize: map the launch grid back to
///        (tile index along seqlen_q, head index, batch index).
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
    return ck_tile::make_tuple(static_cast<index_t>(blockIdx.x),
                               static_cast<index_t>(blockIdx.y),
                               static_cast<index_t>(blockIdx.z));
}
/// @brief 1-D thread block of kBlockSize threads.
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
/// @brief This kernel uses no LDS (shared memory).
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
/// @brief Kernel entry: convert (and, when deterministic, reduce over splits)
///        the fp32 dQ accumulator into the final QGradDataType dQ tensor.
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
    // divide problem: map this block to a (seqlen_q tile, head, batch) triple
    const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
    const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

    long_index_t batch_offset_dq     = 0;
    long_index_t batch_offset_dq_acc = 0;
    if constexpr(kIsGroupMode)
    {
        // get starting offset for each batch from the cumulative seqstart table
        const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
        batch_offset_dq     = query_start * kargs.stride_dq;
        batch_offset_dq_acc = query_start * kargs.stride_dq_acc;

        // get real # queries & # keys under group mode
        const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
        kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
        // seqlen_k is only needed to compute nsplits in the deterministic path
        if constexpr(kIsDeterministic)
        {
            const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
            kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
        }

        // # of required blocks is different in each groups, terminate unnecessary blocks
        // earlier
        if(kargs.seqlen_q <= i_m0)
        {
            return;
        }
    }
    else
    {
        batch_offset_dq     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
        batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
    }

    // for simplicity, batch stride we just modify the pointer
    QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
                            batch_offset_dq;

    // dQAcc/dQ DRAM and DRAM window
    const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
        if constexpr(kIsDeterministic)
        {
            const AccDataType* dq_acc_ptr =
                reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
                batch_offset_dq_acc;

            // deterministic layout keeps one dq_acc slice per kN0-wide key split
            const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);

            // 3-D view: (split, seqlen_q, hdim_q); the split dim is never padded
            auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dq_acc_ptr,
                make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
                number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                number<1>{});
            return pad_tensor_view(
                dq_acc_dram_naive,
                make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
        }
        else
        {
            const AccDataType* dq_acc_ptr =
                reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
                batch_offset_dq_acc;

            // 2-D view: (seqlen_q, hdim_q), no split dimension
            auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dq_acc_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_dq_acc, 1),
                number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                number<1>{});
            return pad_tensor_view(
                dq_acc_dram_naive,
                make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }
    }();

    auto dq_dram = [&]() {
        auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            dq_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_q),
            make_tuple(kargs.stride_dq, 1),
            number<FmhaBwdConvertQGrad::kAlignmentQGrad>{},
            number<1>{});
        return pad_tensor_view(
            dq_dram_naive,
            make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
            sequence<kPadSeqLenQ, kPadHeadDimQ>{});
    }();

    // input window at split 0 (deterministic) or directly at the tile (otherwise)
    auto dq_acc_dram_window = [&]() {
        if constexpr(kIsDeterministic)
        {
            return make_tile_window(
                dq_acc_dram,
                make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                {0, i_m0, 0});
        }
        else
        {
            return make_tile_window(
                dq_acc_dram,
                make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                {i_m0, 0});
        }
    }();

    auto dq_dram_window =
        make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});

    if constexpr(kIsDeterministic)
    {
        // reduce over all key splits, then convert
        const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
        FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
    }
    else
    {
        // convert only
        FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
    }
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
deleted
100644 → 0
View file @
c160c6cf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/// @brief Grid partitioner for the FMHA backward kernel: one block per
///        kN0-sized chunk of the key sequence, heads on y, batches on z.
template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;

    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
    }

    /// @return (block index along seqlen_k, head index, batch index)
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};
/// @brief Grid partitioner for the dO*O (dot) pre-kernel: one block per
///        kBlockSize-sized chunk of the query sequence.
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
    }

    /// @return (block index along seqlen_q, head index, batch index)
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
bd689f40
...
...
@@ -86,7 +86,7 @@ struct FmhaFwdKernel
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
#undef _SS_
#undef _TS_
...
...
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_lse
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
...
...
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
{
kargs
.
lse_ptr
=
lse_ptr
;
kargs
.
nhead_stride_lse
=
nhead_stride_lse
;
kargs
.
batch_stride_lse
=
batch_stride_lse
;
}
if
constexpr
(
kDoFp8StaticQuant
)
{
...
...
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
}
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
batch_offset_lse
=
query_start
;
}
if
constexpr
(
kHasDropout
)
{
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
bd689f40
...
...
@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
#undef _SS_
#undef _TS_
...
...
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
...
...
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
{
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_lse_acc
;
};
struct
GroupModeKargs
...
...
@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for fp8_static_quant args
batch_stride_o
};
batch_stride_o
,
batch_stride_lse_acc
};
if
constexpr
(
kStoreLSE
)
{
...
...
@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_lse_acc
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
batch_stride_lse
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
)
{
...
...
@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
...
...
@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
{
kargs
.
lse_ptr
=
lse_ptr
;
kargs
.
nhead_stride_lse
=
nhead_stride_lse
;
kargs
.
batch_stride_lse
=
batch_stride_lse
;
}
if
constexpr
(
kDoFp8StaticQuant
)
{
...
...
@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
long_index_t
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_o
=
0
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_o
=
0
;
if
constexpr
(
kIsGroupMode
)
{
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_lse_acc
=
query_start
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
query_start
;
}
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
...
...
@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
}
// for simplicity, batch stride we just modify the pointer
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
bd689f40
...
...
@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
#undef _SS_
#undef _TS_
...
...
@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
...
...
@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_lse_acc
;
};
struct
GroupModeKargs
...
...
@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
...
...
@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
{},
// placeholder for dropout
batch_stride_q
,
batch_stride_k
,
batch_stride_v
};
batch_stride_v
,
batch_stride_lse_acc
};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
batch_stride_lse_acc
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
...
...
@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
...
...
@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_randval
=
0
;
const
long_index_t
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
long_index_t
batch_offset_lse_acc
=
0
;
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
...
...
@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_lse_acc
=
query_start
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
...
...
@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
}
else
{
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
0 → 100644
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
/// @brief Pipeline that turns the fp32 dQ accumulator into the final dQ tensor.
/// Two entry points: plain cast (non-deterministic path), and
/// reduce-over-splits + cast (deterministic path, 3-D accumulator layout).
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;

    static constexpr index_t kM0         = Problem::kM0;
    static constexpr index_t kN0         = Problem::kN0;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kQKHeaddim  = Problem::kQKHeaddim;

    static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    // Vector alignments drop to 1 when the head dim is padded, since the last
    // dim is then no longer guaranteed contiguous for full vectors.
    static constexpr index_t kAlignmentQGradAcc =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();

    // No LDS used by this pipeline.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    // Convert only
    /// @brief Load one dq_acc tile, cast to QGradDataType, store to dq.
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
    {
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                      "wrong!");

        // Re-create the input window with the policy's thread distribution.
        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradDramTileDistribution<Problem>());

        auto dq_acc = load_tile(dq_acc_dram_window);

        const auto dq = cast_tile<QGradDataType>(dq_acc);
        store_tile(dq_dram_block_window_tmp, dq);
    }

    // Reduce + Convert
    /// @brief Sum dq_acc over `nsplits` slices along dim 0, cast, store to dq.
    /// @param nsplits number of split slices to reduce (>= 1)
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
               index_t nsplits) const
    {
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                      "wrong!");

        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());

        // Zero-initialized accumulator with the same distribution as the loads.
        auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
        clear_tile(dq_acc);

        constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();

        // Software-pipelined reduction: buffer holds split i while split i+1 is
        // prefetched; the final buffered slice is accumulated after the loop.
        // NOTE(review): with nsplits == 1 the do-while body still runs once and
        // prefetches a slice past the last split before the final accumulate —
        // presumably relying on out-of-range reads returning zero; verify.
        index_t i_total_loops = 0;
        auto dq_acc_buf       = load_tile(dq_acc_dram_window);
        move_tile_window(dq_acc_dram_window, {1, 0, 0});
        do
        {
            sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                    sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                        constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                        dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                    });
                });
            });
            dq_acc_buf = load_tile(dq_acc_dram_window);
            move_tile_window(dq_acc_dram_window, {1, 0, 0});
            i_total_loops += 1;
        } while(i_total_loops < (nsplits - 1));
        // Accumulate the last prefetched slice.
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                });
            });
        });

        // declare dq and element-wise convert the reduced accumulator
        constexpr auto dq_converted_dstr =
            Policy::template MakePostQGradAccDramTileDistribution<Problem>();
        auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);

        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
                });
            });
        });

        // Re-wrap the converted data in the 2-D output distribution; the raw
        // per-thread buffers are copied directly.
        constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
        auto dq                = make_static_distributed_tensor<QGradDataType>(dq_dstr);
        dq.get_thread_buffer() = dq_converted.get_thread_buffer();

        store_tile(dq_dram_block_window_tmp, dq);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp
View file @
bd689f40
...
...
@@ -4,11 +4,11 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
dot_do_o
_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
pipeline
_default_policy.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwd
OGradDotO
DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwd
Pipeline
DefaultPolicy
>
struct
BlockFmhaBwdOGradDotO
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
...
...
@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
Grad
<
Problem
>();
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
0
;
}
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp
deleted
100644 → 0
View file @
c160c6cf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// These templates are not used here: the dO*O dot kernel does not stage any
// tensor in LDS, so every *LoadOnce_ knob of the shared default policy is off.
using BlockFmhaBwdOGradDotODefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ false,
                                      /* KTLoadOnce_ = */ false,
                                      /* VLoadOnce_ = */ false,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_
qs_ks_vr_dos
.hpp
→
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_
kr_ktr_vr
.hpp
View file @
bd689f40
...
...
@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
dq_dk_dv_pipeline_qs_ks_vr_dos
_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
pipeline
_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwd
DQDKDV
Pipeline
QSKSVROGradS
DefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipeline
QSKSVROGradS
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwdPipelineDefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipeline
KRKTRVR
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
...
...
@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
using
VGradDataType
=
remove_cvref_t
<
typename
Problem
::
VGradDataType
>
;
using
BiasGradDataType
=
remove_cvref_t
<
typename
Problem
::
BiasGradDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
HotLoopScheduler
=
typename
Policy
::
template
HotLoopScheduler
<
Problem
>;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
...
...
@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
bool
kQLoadOnce
=
true
;
static
constexpr
bool
kQTLoadOnce
=
false
;
static
constexpr
bool
kKLoadOnce
=
true
;
static
constexpr
bool
kKTLoadOnce
=
false
;
static
constexpr
bool
kVLoadOnce
=
true
;
static
constexpr
bool
kOGradLoadOnce
=
true
;
static
constexpr
bool
kOGradTLoadOnce
=
false
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kIsDeterministic
=
Problem
::
kIsDeterministic
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
kPadHeadDimQ
?
2
:
Policy
::
template
GetAlignmentQGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
1
;
static
constexpr
index_t
kAlignmentKGrad
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentKGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentVGrad
=
...
...
@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetTransposedAlignmentBias
<
Problem
>();
static
constexpr
const
char
*
name
=
"
qs_ks_vr_dos
"
;
static
constexpr
const
char
*
name
=
"
kr_ktr_vr
"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
...
...
@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
}
template
<
typename
QDramBlockWindowTmp
,
typename
QTDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KTDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
OGradDramBlockWindowTmp
,
typename
OGradTDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
DDramBlockWindowTmp
,
typename
QGradDramBlockWindowTmp
,
...
...
@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
const
QTDramBlockWindowTmp
&
/*qt_dram_block_window_tmp*/
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
const
KTDramBlockWindowTmp
&
/*kt_dram_block_window_tmp*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
const
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
const
OGradDramBlockWindowTmp
&
do_dram_block_window_tmp
,
const
OGradTDramBlockWindowTmp
&
/*dot_dram_block_window_tmp*/
,
const
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
const
DDramBlockWindowTmp
&
d_dram_block_window_tmp
,
const
QGradDramBlockWindowTmp
&
dq_dram_block_window_tmp
,
...
...
@@ -122,13 +107,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
,
#endif
float
rp_undrop
,
float
scale_rp_undrop
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
@@ -138,9 +121,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
remove_cvref_t
<
typename
OGradDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
LSEDataType
,
remove_cvref_t
<
typename
LSEDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QGradDataType
,
remove_cvref_t
<
typename
QGradDramBlockWindowTmp
::
DataType
>>
,
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
...
...
@@ -156,77 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// Q tile in LDS
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// QT tile in LDS
auto
qt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptorAsQT
<
Problem
>());
auto
qt_lds_window
=
make_tile_window
(
qt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
});
// K tile in LDS
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// KT tile in LDS
auto
kt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptorAsKT
<
Problem
>());
auto
kt_lds_window
=
make_tile_window
(
kt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// OGrad tile in LDS
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kVHeaddim
>
{}),
{
0
,
0
});
// OGradT tile in LDS
auto
dot_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptorAsOGradT
<
Problem
>());
auto
dot_lds_window
=
make_tile_window
(
dot_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
});
// SGrad tile in LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType
*
biast_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
biast_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
biast_lds_ptr
,
Policy
::
template
MakeBiasTLdsBlockDescriptor
<
Problem
>());
auto
biast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
dbiast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetPTOGradTBlockGemm
<
Problem
>();
...
...
@@ -234,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
constexpr
auto
gemm_3
=
Policy
::
template
GetSGradTQTBlockGemm
<
Problem
>();
constexpr
auto
gemm_4
=
Policy
::
template
GetSGradKTBlockGemm
<
Problem
>();
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVInRegDramTileDistribution
<
Problem
,
decltype
(
gemm_2
)>());
auto
v
=
load_tile
(
v_dram_window
);
// persistent V register tile
using
SPTBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradTBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
// init VGrad & KGrad
auto
dv_acc
=
decltype
(
gemm_1
.
MakeCBlockTile
()){};
auto
dk_acc
=
decltype
(
gemm_3
.
MakeCBlockTile
()){};
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
// K, HBM ->LDS ->Reg
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
k_origin
=
k_dram_window
.
get_window_origin
();
// Early termination
const
auto
[
seqlen_q_start
,
seqlen_q_end
]
=
mask
.
GetTileRangeAlongY
(
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
...
...
@@ -274,217 +169,408 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
}
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKRegSliceBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
VDataType
*
v_lds_ptr
=
static_cast
<
VDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto
shuffled_k_block_tile
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeShuffledKRegWriteBlockDescriptor
<
Problem
>());
KDataType
*
kt_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
shuffled_k_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
auto
kt_lds_read_window
=
make_tile_window
(
kt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeKTRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto
k_block_tile
=
load_tile
(
k_dram_window
);
auto
v_block_tile
=
load_tile
(
v_dram_window
);
store_tile
(
k_lds_write_window
,
k_block_tile
);
shuffle_tile
(
shuffled_k_block_tile
,
k_block_tile
);
store_tile
(
shuffled_k_lds_write_window
,
shuffled_k_block_tile
);
block_sync_lds
();
k_reg_tensor
=
load_tile
(
k_lds_read_window
);
block_sync_lds
();
auto
kt_reg_tensor
=
load_tile
(
kt_lds_read_window
);
store_tile
(
k
_lds_window
,
k
_block_tile
);
// // persistent K in LDS
store_tile
(
v
_lds_
write_
window
,
v
_block_tile
);
auto
q_dram_block_window
=
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
block_sync_lds
();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
q_lds_window
.
get_window_origin
(),
Policy
::
template
MakeQRegSliceBlockDescriptor
<
Problem
>());
auto
pt_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakePTRegSliceBlockDescriptor
<
Problem
>());
// QT: Reg -> Reg-> LDS
auto
shuffled_q_block_tile
=
make_static_distributed_tensor
<
QDataType
>
(
Policy
::
template
MakeShuffledQRegWriteBlockDescriptor
<
Problem
>());
auto
do_dram_block_window
=
QDataType
*
qt_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
shuffled_q_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
auto
qt_lds_read_window
=
make_tile_window
(
qt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeQTRegSliceBlockDescriptor
<
Problem
>());
// dO: HBM ->Reg ->LDS
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window_tmp
.
get_bottom_tensor_view
(),
do_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()));
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
do_lds_window
.
get_window_origin
(),
Policy
::
template
MakeOGradRegSliceBlockDescriptor
<
Problem
>());
// dOT: Reg ->Reg ->LDS
auto
shuffled_do_block_tile
=
make_static_distributed_tensor
<
OGradDataType
>
(
Policy
::
template
MakeShuffledOGradRegWriteBlockDescriptor
<
Problem
>());
OGradDataType
*
dot_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
dq_dram_block_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
shuffled_do_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
lse_dram_block_window
=
make_tile_window
(
lse_dram_block_window_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
d_dram_block_window
=
make_tile_window
(
d_dram_block_window_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
auto
dot_lds_read_window
=
make_tile_window
(
dot_read_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeOGradTRegSliceBlockDescriptor
<
Problem
>());
// dS: Reg -> Reg -> LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
ds_lds_read_window
=
make_tile_window
(
ds_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK4
>
{}),
ds_lds_window
.
get_window_origin
(),
Policy
::
template
MakeSGradRegSliceBlockDescriptor
<
Problem
>());
auto
dst_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakeSGradTRegSliceBlockDescriptor
<
Problem
>());
// Bias: HBM ->Reg ->Reg ->LDS
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_block_window
=
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})});
// M/N
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})},
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_block_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
BiasDataType
*
bias_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
bias_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
bias_lds_ptr
,
Policy
::
template
MakeBiasLdsBlockDescriptor
<
Problem
>());
auto
bias_lds_write_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
bias_s_lds_read_window
=
make_tile_window
(
bias_lds_write_window
.
get_bottom_tensor_view
(),
bias_lds_write_window
.
get_window_lengths
(),
bias_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeBiasSTileDistribution
<
decltype
(
gemm_0
)>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// LSE: HBM -> LDS ->Reg
auto
lse_dram_window
=
make_tile_window
(
lse_dram_block_window
.
get_bottom_tensor_view
(),
lse_dram_block_window
.
get_window_lengths
(),
l
se
_dram_block_window
.
get_window_origin
()
,
lse_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window
_tmp
.
get_window_lengths
(),
{
se
qlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
LSEDataType
*
lse_lds_ptr
=
static_cast
<
LSEDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
lse_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
lse_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
lse_lds_write_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
lse_lds_read_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
},
Policy
::
template
MakeLSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
// D: HBM ->Reg
auto
d_dram_window
=
make_tile_window
(
d_dram_block_window
.
get_bottom_tensor_view
(),
d_dram_block_window
.
get_window_lengths
(),
d_dram_block_window
.
get_window_origin
()
,
d_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window
_tmp
.
get_window_lengths
(),
{
seqlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window
.
get_bottom_tensor_view
(),
bias_dram_block_window
.
get_window_lengths
(),
bias_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
DDataType
*
d_lds_ptr
=
static_cast
<
DDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()));
auto
d_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
d_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
d_lds_write_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
biast_lds_window
=
make_tile_window
(
biast_lds_shuffle_window
.
get_bottom_tensor_view
()
,
biast_lds_shuffle_window
.
get_window_lengths
(
),
biast_lds_shuffle_window
.
get_window_origin
()
,
Policy
::
template
Make
BiasTTileDistribution
<
decltype
(
gemm_0
)>());
auto
d_lds_read_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}
),
{
0
}
,
Policy
::
template
Make
LSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>
(
// RandVal: HBM ->Reg
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>(
randval_dram_block_window_tmp
,
seqlen_q_start
);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kM0
/
kK1
;
constexpr
index_t
k2_loops
=
kVHeaddim
/
kK2
;
constexpr
index_t
k3_loops
=
kM0
/
kK3
;
// BiasGrad
// Reg ->LDS ->Reg ->HBM
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
auto
dbias_lds_read_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
// ----------------------------Loop write out------------------------------//
auto
dq_dram_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
using
SPBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
==
kK0
,
"kQKHeaddim should equal to kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
==
kK2
,
"kVHeaddim should equal to kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
do
{
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window
.
get_bottom_tensor_view
(),
q_dram_block_window
.
get_window_lengths
(),
q_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
// Q DRAM tile window for
// load
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window
.
get_bottom_tensor_view
(),
do_dram_block_window
.
get_window_lengths
(),
do_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
// OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{}
;
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
)
;
__builtin_amdgcn_sched_barrier
(
0
);
// Hot loop
while
(
i_total_loops
<
num_total_loop
)
{
auto
q_block_tile
=
load_tile
(
q_dram_window
);
clear_tile
(
st_acc
);
// Initialize S^T
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
auto
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
if
constexpr
(
k0_loops
>
1
)
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
st_acc
,
get_slice_tile
(
q_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
block_sync_lds
();
});
}
store_tile
(
q_lds_window
,
q_block_tile
);
shuffle_tile
(
shuffled_q_block_tile
,
q_block_tile
);
store_tile
(
shuffled_q_lds_write_window
,
shuffled_q_block_tile
);
auto
do_block_tile
=
load_tile
(
do_dram_window
);
// prefetch load OGrad tile
{
// tail
block_sync_lds
();
gemm_0
(
st_acc
,
get_slice_tile
(
q_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kN0
,
k0_loops
*
kK0
>
{}));
block_sync_lds
();
}
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
block_sync_lds
();
auto
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
auto
lse
=
load_tile
(
lse_lds_read_window
);
block_sync_lds
();
// STAGE 1, Q@K Gemm0
auto
s_acc
=
SPBlockTileType
{};
s_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
block_sync_lds
(
);
auto
bias_
shuffle
_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
auto
shuffle
d_bias_tile
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_
shuffle
_tmp
,
bias_tile
);
store_tile
(
bias
t
_lds_
shuffl
e_window
,
bias_
shuffle
_tmp
);
shuffle_tile
(
shuffle
d_bias_tile
,
bias_tile
);
store_tile
(
bias_lds_
writ
e_window
,
shuffle
d_bias_tile
);
block_sync_lds
();
auto
bias
t
_tile
=
load_tile
(
bias
t
_lds_window
);
auto
bias
_s
_tile
=
load_tile
(
bias
_s
_lds_
read_
window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
=
raw_scale
*
x
+
type_convert
<
AccDataType
>
(
y
);
#else
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
#endif
},
s
t
_acc
,
bias
t
_tile
);
s_acc
,
bias
_s
_tile
);
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
constexpr
auto
st_spans
=
decltype
(
st_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
st_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
st_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s
t
_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc
(
i_j_idx
)
*=
raw_scale
;
#else
st_acc
(
i_j_idx
)
*=
scale
;
#endif
position_encoding
.
update
(
st_acc
(
i_j_idx
),
row
,
col
);
s_acc
(
i_j_idx
)
*=
scale
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
st_acc
);
#endif
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s
t
_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
set_tile_if
(
s_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
lse
=
load_tile
(
lse_dram_window
);
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
...
...
@@ -499,157 +585,162 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
}
};
auto
p
t
=
SP
T
BlockTileType
{};
constexpr
auto
p
t
_spans
=
decltype
(
p
t
)
::
get_distributed_spans
();
sweep_tile_span
(
p
t
_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
auto
p
=
SPBlockTileType
{};
constexpr
auto
p_spans
=
decltype
(
p
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
#endif
sweep_tile_span
(
pt_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p
t
(
i_j_idx
)
=
exp2
(
s
t
_acc
[
i_j_idx
]
-
row_lse
);
p
(
i_j_idx
)
=
exp2
(
s_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
p
t
(
i_j_idx
)
=
exp2
(
scale
*
s
t
_acc
[
i_j_idx
]
-
row_lse
);
p
(
i_j_idx
)
=
exp2
(
scale
*
s_acc
[
i_j_idx
]
-
row_lse
);
}
#else
pt
(
i_j_idx
)
=
exp
(
st_acc
[
i_j_idx
]
-
get_validated_lse
(
lse
[
i_idx
]));
#endif
});
});
if
constexpr
(
kHa
sDropout
)
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
dropout
.
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>
(
seqlen_q_st
art
+
i_total_loops
*
kM0
,
p
t
,
randval_dram_window
);
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_st
ep
,
k_origin
.
at
(
number
<
0
>
{})
,
p
,
randval_dram_window
);
}
// STAGE 3, P^T@OGrad^T Gemm1
block_sync_lds
();
store_tile
(
do_lds_window
,
do_block_tile
);
// store the prefetch
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
kHasDropout
)
const
auto
p_gemm
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
p
t
);
p
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
p
t
);
return
cast_tile
<
GemmDataType
>
(
p
);
}
}();
static_for
<
0
,
k1_loops
,
1
>
{}([
&
](
auto
i_k1
)
{
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
i_k1
*
kK1
,
0
>
{},
sequence
<
(
i_k1
+
1
)
*
kK1
,
kN0
>
{}),
get_slice_tile
(
dot_lds_window
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kVHeaddim
,
(
i_k1
+
1
)
*
kK1
>
{}));
block_sync_lds
();
});
// STAGE 3, P^T@OGrad^T Gemm1
auto
do_block_tile
=
load_tile
(
do_dram_window
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
auto
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
store_tile
(
do_lds_window
,
do_block_tile
);
shuffle_tile
(
shuffled_do_block_tile
,
do_block_tile
);
store_tile
(
shuffled_do_lds_write_window
,
shuffled_do_block_tile
);
store_tile
(
d_lds_write_window
,
d_block_tile
);
block_sync_lds
();
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
block_sync_lds
();
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
p_gemm
)>(
pt_reg_tensor
,
p_gemm
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
clear_tile
(
dpt_acc
);
// Initialize PGrad^T
auto
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
auto
d
=
load_tile
(
d_lds_read_window
);
block_sync_lds
();
static_for
<
0
,
k2_loops
,
1
>
{}([
&
](
auto
i_k2
)
{
block_sync_lds
();
gemm_2
(
dpt_acc
,
get_slice_tile
(
do_lds_window
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kM0
,
(
i_k2
+
1
)
*
kK2
>
{}),
get_slice_tile
(
v
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kN0
,
(
i_k2
+
1
)
*
kK2
>
{}));
block_sync_lds
();
});
auto
dp_acc
=
SPGradBlockTileType
{};
// STAGE 5, P^T(PGrad^T - D)
const
auto
d
=
load_tile
(
d_dram_window
);
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
auto
dst
=
SPGradTBlockTileType
{};
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
sweep_tile_span
(
dst_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
// STAGE 5, P^T(PGrad^T - D)
auto
ds
=
SPGradBlockTileType
{};
constexpr
auto
ds_spans
=
decltype
(
ds
)
::
get_distributed_spans
();
sweep_tile_span
(
ds_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
ds
t
_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
sweep_tile_span
(
ds_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
p
t
[
i_j_idx
]
>=
0
;
ds
t
(
i_j_idx
)
=
pt
[
i_j_idx
]
*
(
!
kHasDropout
||
undrop_flag
?
(
dpt_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
bool
undrop_flag
=
p
[
i_j_idx
]
>=
0
;
ds
(
i_j_idx
)
=
p
[
i_j_idx
]
*
(
!
FmhaDropout
::
IsDropout
||
undrop_flag
?
(
dp_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbias
t
=
[
&
]()
{
if
constexpr
(
kHa
sDropout
)
const
auto
dbias
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
ds
t
);
ds
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
ds
t
);
return
cast_tile
<
BiasGradDataType
>
(
ds
);
}
}();
store_tile
(
bias
t
_lds_
shuffl
e_window
,
dbias
t
);
store_tile
(
bias_lds_
writ
e_window
,
dbias
);
block_sync_lds
();
auto
dbias
t
_tile
=
load_tile
(
dbias
t
_lds_
shuffle
_window
);
auto
dbias
t_shuffle_tmp
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
auto
shuffled_
dbias_tile
=
load_tile
(
dbias_lds_
read
_window
);
auto
dbias
_tile
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
store_tile
(
dbias_dram_block_window
,
dbiast_shuffle_tmp
);
move_tile_window
(
dbias_dram_block_window
,
{
kM0
,
0
});
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
move_tile_window
(
dbias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
block_sync_lds
();
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
static_for
<
0
,
k3_loops
,
1
>
{}([
&
](
auto
i_k3
)
{
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
i_k3
*
kK3
,
0
>
{},
sequence
<
(
i_k3
+
1
)
*
kK3
,
kN0
>
{}),
get_slice_tile
(
qt_lds_window
,
sequence
<
0
,
i_k3
*
kK3
>
{},
sequence
<
kQKHeaddim
,
(
i_k3
+
1
)
*
kK3
>
{}));
block_sync_lds
();
});
const
auto
ds_gemm
=
cast_tile
<
GemmDataType
>
(
ds
);
// STAGE 7, SGrad@K^T Gemm4
store_tile
(
ds_lds_window
,
dst_gemm
);
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
ds_gemm
)>(
dst_reg_tensor
,
ds_gemm
);
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
// Initialize QGrad
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
store_tile
(
ds_lds_window
,
ds_gemm
);
block_sync_lds
();
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
// STAGE7 SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
gemm_4
(
dq_acc
,
get_slice_tile
(
ds_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kM0
,
(
i_k4
+
1
)
*
kK4
>
{}),
get_slice_tile
(
kt_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{}));
});
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor_next
=
load_tile
(
ds_lds_read_window
);
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
}
auto
kt_reg_tensor_slice
=
get_slice_tile
(
kt_reg_tensor
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{});
gemm_4
(
dq_acc
,
ds_reg_tensor
,
kt_reg_tensor_slice
);
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor
.
get_thread_buffer
()
=
ds_reg_tensor_next
.
get_thread_buffer
();
}
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
-
kN0
});
// QGrad Scale
if
constexpr
(
kHa
sDropout
)
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
...
...
@@ -658,34 +749,33 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
}
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
update_tile
(
dq_dram_block_window
,
dq
);
if
constexpr
(
kIsDeterministic
)
{
store_tile
(
dq_dram_window
,
dq_acc
);
}
else
{
update_tile
(
dq_dram_window
,
dq_acc
);
}
move_tile_window
(
dq_dram_window
,
{
kM0
,
0
});
// move tile windows
move_tile_window
(
q_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
dq_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
do_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
lse_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
}
while
(
++
i_total_loops
<
num_total_loop
);
i_total_loops
+=
1
;
seqlen_q_step
+=
kM0
;
}
//
KGrad
Scale
if
constexpr
(
kHa
sDropout
)
//
Results
Scale
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dk_acc
);
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
}
// VGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_k
s
_kt
s
_vr.hpp
→
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_k
r
_kt
r
_vr
_iglp
.hpp
View file @
bd689f40
...
...
@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
dq_dk_dv_pipeline_ks_kts_vr
_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
pipeline
_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwd
DQDKDV
Pipeline
KSKTSVR
DefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipelineK
S
KT
S
VR
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwdPipelineDefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipelineK
R
KT
R
VR
IGLP
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
...
...
@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
using
VGradDataType
=
remove_cvref_t
<
typename
Problem
::
VGradDataType
>
;
using
BiasGradDataType
=
remove_cvref_t
<
typename
Problem
::
BiasGradDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
HotLoopScheduler
=
typename
Policy
::
template
HotLoopScheduler
<
Problem
>;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
...
...
@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
bool
kQLoadOnce
=
false
;
static
constexpr
bool
kQTLoadOnce
=
false
;
static
constexpr
bool
kKLoadOnce
=
true
;
static
constexpr
bool
kKTLoadOnce
=
true
;
static
constexpr
bool
kVLoadOnce
=
true
;
static
constexpr
bool
kOGradLoadOnce
=
false
;
static
constexpr
bool
kOGradTLoadOnce
=
false
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kIsDeterministic
=
Problem
::
kIsDeterministic
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
kPadHeadDimQ
?
2
:
Policy
::
template
GetAlignmentQGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
1
;
static
constexpr
index_t
kAlignmentKGrad
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentKGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentVGrad
=
...
...
@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetTransposedAlignmentBias
<
Problem
>();
static
constexpr
const
char
*
name
=
"k
s
_kt
s
_vr"
;
static
constexpr
const
char
*
name
=
"k
r
_kt
r
_vr
_iglp
"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
...
...
@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
}
template
<
typename
QDramBlockWindowTmp
,
typename
QTDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KTDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
OGradDramBlockWindowTmp
,
typename
OGradTDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
DDramBlockWindowTmp
,
typename
QGradDramBlockWindowTmp
,
...
...
@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
const
QTDramBlockWindowTmp
&
qt_dram_block_window_tmp
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
const
KTDramBlockWindowTmp
&
kt_dram_block_window_tmp
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
const
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
const
OGradDramBlockWindowTmp
&
do_dram_block_window_tmp
,
const
OGradTDramBlockWindowTmp
&
dot_dram_block_window_tmp
,
const
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
const
DDramBlockWindowTmp
&
d_dram_block_window_tmp
,
const
QGradDramBlockWindowTmp
&
dq_dram_block_window_tmp
,
...
...
@@ -122,43 +107,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
,
#endif
float
rp_undrop
,
float
scale_rp_undrop
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
LSEDataType
,
remove_cvref_t
<
typename
LSEDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QGradDataType
,
remove_cvref_t
<
typename
QGradDramBlockWindowTmp
::
DataType
>>
,
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kQKHeaddim
==
QTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kQKHeaddim
==
KTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
OGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kVHeaddim
==
OGradTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
LSEDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
DDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
QGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
...
...
@@ -166,83 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// Q tile in LDS
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// QT tile in LDS
QDataType
*
qt_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
qt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsBlockDescriptor
<
Problem
>());
auto
qt_lds_window
=
make_tile_window
(
qt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kK3
>
{}),
{
0
,
0
});
// K tile in LDS
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// KT tile in LDS
KDataType
*
kt_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
kt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsBlockDescriptor
<
Problem
>());
auto
kt_lds_window
=
make_tile_window
(
kt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// OGrad tile in LDS
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
// OGradT tile in LDS
OGradDataType
*
dot_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
dot_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsBlockDescriptor
<
Problem
>());
auto
dot_lds_window
=
make_tile_window
(
dot_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kK1
>
{}),
{
0
,
0
});
// SGrad tile in LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType
*
biast_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
biast_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
biast_lds_ptr
,
Policy
::
template
MakeBiasTLdsBlockDescriptor
<
Problem
>());
auto
biast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
dbiast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetPTOGradTBlockGemm
<
Problem
>();
...
...
@@ -250,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
constexpr
auto
gemm_3
=
Policy
::
template
GetSGradTQTBlockGemm
<
Problem
>();
constexpr
auto
gemm_4
=
Policy
::
template
GetSGradKTBlockGemm
<
Problem
>();
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVInRegDramTileDistribution
<
Problem
,
decltype
(
gemm_2
)>());
auto
v
=
load_tile
(
v_dram_window
);
// persistent V register tile
using
SPTBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradTBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
// init VGrad & KGrad
auto
dv_acc
=
decltype
(
gemm_1
.
MakeCBlockTile
()){};
auto
dk_acc
=
decltype
(
gemm_3
.
MakeCBlockTile
()){};
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
// K, HBM ->LDS ->Reg
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
k_origin
=
k_dram_window
.
get_window_origin
();
// Early termination
const
auto
[
seqlen_q_start
,
seqlen_q_end
]
=
mask
.
GetTileRangeAlongY
(
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
...
...
@@ -290,272 +169,444 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
}
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_block_tile
=
load_tile
(
k_dram_window
);
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKRegSliceBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
VDataType
*
v_lds_ptr
=
static_cast
<
VDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
store_tile
(
k_lds_window
,
k_block_tile
);
// // persistent K in LDS
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto
shuffled_k_block_tile
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeShuffledKRegWriteBlockDescriptor
<
Problem
>());
auto
kt_dram_block_window
=
kt_dram_block_window_tmp
;
KDataType
*
kt_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
shuffled_k_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
auto
kt_lds_read_window
=
make_tile_window
(
kt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeKTRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto
k_block_tile
=
load_tile
(
k_dram_window
);
auto
v_block_tile
=
load_tile
(
v_dram_window
);
auto
kt_dram_window
=
make_tile_window
(
kt_dram_block_window
.
get_bottom_tensor_view
(),
kt_dram_block_window
.
get_window_lengths
(),
kt_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKTDramTileDistribution
<
Problem
>());
// K^T DRAM tile window for
// load
store_tile
(
k_lds_write_window
,
k_block_tile
);
shuffle_tile
(
shuffled_k_block_tile
,
k_block_tile
);
store_tile
(
shuffled_k_lds_write_window
,
shuffled_k_block_tile
);
auto
kt_block_tile
=
load_tile
(
kt_dram_window
);
block_sync_lds
();
k_reg_tensor
=
load_tile
(
k_lds_read_window
);
block_sync_lds
();
auto
kt_shuffle_tmp
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeShuffledKTRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
kt_shuffle_tmp
,
kt_block_tile
);
auto
kt_reg_tensor
=
load_tile
(
kt_lds_read_window
);
store_tile
(
kt
_lds_window
,
kt_shuffle_tmp
);
// persistent K^T in LDS
store_tile
(
v
_lds_
write_
window
,
v_block_tile
);
auto
q_dram_block_window
=
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
q_lds_window
.
get_window_origin
(),
Policy
::
template
MakeQRegSliceBlockDescriptor
<
Problem
>());
auto
qt_dram_block_window
=
make_tile_window
(
qt_dram_block_window_tmp
.
get_bottom_tensor_view
(),
qt_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_q_start
});
auto
pt_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakePTRegSliceBlockDescriptor
<
Problem
>());
// QT: Reg -> Reg-> LDS
auto
shuffled_q_block_tile
=
make_static_distributed_tensor
<
QDataType
>
(
Policy
::
template
MakeShuffledQRegWriteBlockDescriptor
<
Problem
>());
auto
do_dram_block_window
=
QDataType
*
qt_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
shuffled_q_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
auto
qt_lds_read_window
=
make_tile_window
(
qt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeQTRegSliceBlockDescriptor
<
Problem
>());
// dO: HBM ->Reg ->LDS
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window_tmp
.
get_bottom_tensor_view
(),
do_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
auto
dot_dram_block_window
=
make_tile_window
(
dot_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dot_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_q_start
});
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()));
auto
dq_dram_block_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
lse_dram_block_window
=
make_tile_window
(
lse_dram_block_window_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
do_lds_window
.
get_window_origin
(),
Policy
::
template
MakeOGradRegSliceBlockDescriptor
<
Problem
>());
// dOT: Reg ->Reg ->LDS
auto
shuffled_do_block_tile
=
make_static_distributed_tensor
<
OGradDataType
>
(
Policy
::
template
MakeShuffledOGradRegWriteBlockDescriptor
<
Problem
>());
OGradDataType
*
dot_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
shuffled_do_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
d_dram_block_window
=
make_tile_window
(
d_dram_block_window_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
auto
dot_lds_read_window
=
make_tile_window
(
dot_read_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeOGradTRegSliceBlockDescriptor
<
Problem
>());
// dS: Reg -> Reg -> LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
ds_lds_read_window
=
make_tile_window
(
ds_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK4
>
{}),
ds_lds_window
.
get_window_origin
(),
Policy
::
template
MakeSGradRegSliceBlockDescriptor
<
Problem
>());
auto
dst_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakeSGradTRegSliceBlockDescriptor
<
Problem
>());
// Bias: HBM ->Reg ->Reg ->LDS
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_block_window
=
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})});
// M/N
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})},
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_block_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
BiasDataType
*
bias_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
bias_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
bias_lds_ptr
,
Policy
::
template
MakeBiasLdsBlockDescriptor
<
Problem
>());
auto
qt_dram_window
=
make_tile_window
(
qt_dram_block_window
.
get_bottom_tensor_view
(),
qt_dram_block_window
.
get_window_lengths
(),
qt_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQTDramTileDistribution
<
Problem
>());
auto
bias_lds_write_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
dot_dram
_window
=
make_tile_window
(
dot_dram_block
_window
.
get_bottom_tensor_view
(),
dot_dram_block
_window
.
get_window_lengths
(),
dot_dram_block
_window
.
get_window_origin
(),
Policy
::
template
Make
OGradTDram
TileDistribution
<
Problem
>());
auto
bias_s_lds_read
_window
=
make_tile_window
(
bias_lds_write
_window
.
get_bottom_tensor_view
(),
bias_lds_write
_window
.
get_window_lengths
(),
bias_lds_write
_window
.
get_window_origin
(),
Policy
::
template
Make
BiasS
TileDistribution
<
decltype
(
gemm_0
)
>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// LSE: HBM -> LDS ->Reg
auto
lse_dram_window
=
make_tile_window
(
lse_dram_block_window
.
get_bottom_tensor_view
(),
lse_dram_block_window
.
get_window_lengths
(),
l
se
_dram_block_window
.
get_window_origin
()
,
lse_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window
_tmp
.
get_window_lengths
(),
{
se
qlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
LSEDataType
*
lse_lds_ptr
=
static_cast
<
LSEDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
lse_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
lse_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
lse_lds_write_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
lse_lds_read_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
},
Policy
::
template
MakeLSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
// D: HBM ->Reg
auto
d_dram_window
=
make_tile_window
(
d_dram_block_window
.
get_bottom_tensor_view
(),
d_dram_block_window
.
get_window_lengths
(),
d_dram_block_window
.
get_window_origin
()
,
d_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window
_tmp
.
get_window_lengths
(),
{
seqlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window
.
get_bottom_tensor_view
(),
bias_dram_block_window
.
get_window_lengths
(),
bias_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
DDataType
*
d_lds_ptr
=
static_cast
<
DDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>())
)
;
auto
biast_lds_window
=
make_tile_window
(
biast_lds_shuffle_window
.
get_bottom_tensor_view
(),
biast_lds_shuffle_window
.
get_window_lengths
(),
biast_lds_shuffle_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTTileDistribution
<
decltype
(
gemm_0
)>());
auto
d_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
d_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>
(
auto
d_lds_write_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
d_lds_read_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
},
Policy
::
template
MakeLSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
// RandVal: HBM ->Reg
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>(
randval_dram_block_window_tmp
,
seqlen_q_start
);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kM0
/
kK1
;
constexpr
index_t
k2_loops
=
kVHeaddim
/
kK2
;
constexpr
index_t
k3_loops
=
kM0
/
kK3
;
// BiasGrad
// Reg ->LDS ->Reg ->HBM
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
auto
dbias_lds_read_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
// ----------------------------Loop write out------------------------------//
auto
dq_dram_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
using
SPBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
==
kK0
,
"kQKHeaddim should equal to kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
==
kK2
,
"kVHeaddim should equal to kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
do
{
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window
.
get_bottom_tensor_view
(),
q_dram_block_window
.
get_window_lengths
(),
q_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
// Q DRAM tile window for
// load
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window
.
get_bottom_tensor_view
(),
do_dram_block_window
.
get_window_lengths
(),
do_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
// OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{};
/*
* Prefetch Q, LSE, dO, D
*/
auto
q_block_tile
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
auto
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
auto
q_block_tile
=
load_tile
(
q_dram_window
);
{
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
auto
do_block_tile
=
load_tile
(
do_dram_window
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
clear_tile
(
st_acc
);
// Initialize S^T
auto
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write 0
q_block_tile
=
load_tile
(
q_dram_window
);
// global read 1
}
/*
* Store prefetched data into LDS
*/
block_sync_lds
();
store_tile
(
q_lds_window
,
q_block_tile
);
shuffle_tile
(
shuffled_q_block_tile
,
q_block_tile
);
store_tile
(
shuffled_q_lds_write_window
,
shuffled_q_block_tile
);
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
block_sync_lds
();
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write i + 1
q_block_tile
=
load_tile
(
q_dram_window
);
// global read i + 2
});
}
store_tile
(
do_lds_window
,
do_block_tile
);
shuffle_tile
(
shuffled_do_block_tile
,
do_block_tile
);
store_tile
(
shuffled_do_lds_write_window
,
shuffled_do_block_tile
);
const
auto
dot_prefetch
=
load_tile
(
dot_dram_window
);
// prefetch load OGrad^T tile
{
// tail
block_sync_lds
();
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
sequence
<
kN0
,
(
k0_loops
-
1
)
*
kK0
>
{}));
block_sync_lds
();
store_tile
(
d_lds_write_window
,
d_block_tile
);
block_sync_lds
();
store_tile
(
q_lds_window
,
q_block_tile
);
block_sync_lds
();
/*
* Prefetch LDS data into Reg to Asynchronous Data Movement and MFMA pipeline
*/
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kN0
,
k0_loops
*
kK0
>
{}));
}
auto
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
auto
lse
=
load_tile
(
lse_lds_read_window
);
auto
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
auto
d
=
load_tile
(
d_lds_read_window
);
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
__builtin_amdgcn_sched_barrier
(
0
);
// Hot loop
while
(
i_total_loops
<
(
num_total_loop
-
1
))
{
// STAGE 1, Q@K Gemm0
auto
s_acc
=
SPBlockTileType
{};
q_block_tile
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
do_block_tile
=
load_tile
(
do_dram_window
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
s_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
0
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
block_sync_lds
(
);
auto
bias_
shuffle
_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
auto
shuffle
d_bias_tile
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_
shuffle
_tmp
,
bias_tile
);
store_tile
(
bias
t
_lds_
shuffl
e_window
,
bias_
shuffle
_tmp
);
shuffle_tile
(
shuffle
d_bias_tile
,
bias_tile
);
store_tile
(
bias_lds_
writ
e_window
,
shuffle
d_bias_tile
);
block_sync_lds
();
auto
bias
t
_tile
=
load_tile
(
bias
t
_lds_window
);
auto
bias
_s
_tile
=
load_tile
(
bias
_s
_lds_
read_
window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
=
raw_scale
*
x
+
type_convert
<
AccDataType
>
(
y
);
#else
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
#endif
},
s
t
_acc
,
bias
t
_tile
);
s_acc
,
bias
_s
_tile
);
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
constexpr
auto
st_spans
=
decltype
(
st_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
st_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
st_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s
t
_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc
(
i_j_idx
)
*=
raw_scale
;
#else
st_acc
(
i_j_idx
)
*=
scale
;
#endif
position_encoding
.
update
(
st_acc
(
i_j_idx
),
row
,
col
);
s_acc
(
i_j_idx
)
*=
scale
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
st_acc
);
#endif
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s
t
_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
set_tile_if
(
s_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
lse
=
load_tile
(
lse_dram_window
);
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
...
...
@@ -570,278 +621,416 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
}
};
auto
p
t
=
SP
T
BlockTileType
{};
constexpr
auto
p
t
_spans
=
decltype
(
p
t
)
::
get_distributed_spans
();
sweep_tile_span
(
p
t
_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
auto
p
=
SPBlockTileType
{};
constexpr
auto
p_spans
=
decltype
(
p
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
#endif
sweep_tile_span
(
pt_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p
t
(
i_j_idx
)
=
exp2
(
s
t
_acc
[
i_j_idx
]
-
row_lse
);
p
(
i_j_idx
)
=
exp2
(
s_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
p
t
(
i_j_idx
)
=
exp2
(
scale
*
s
t
_acc
[
i_j_idx
]
-
row_lse
);
p
(
i_j_idx
)
=
exp2
(
scale
*
s_acc
[
i_j_idx
]
-
row_lse
);
}
#else
pt
(
i_j_idx
)
=
exp
(
st_acc
[
i_j_idx
]
-
get_validated_lse
(
lse
[
i_idx
]));
#endif
});
});
auto
dot_shuffle_tmp
=
make_static_distributed_tensor
<
OGradDataType
>
(
Policy
::
template
MakeShuffledOGradTRegBlockDescriptor
<
Problem
>());
block_sync_lds
();
{
shuffle_tile
(
dot_shuffle_tmp
,
dot_prefetch
);
store_tile
(
dot_lds_window
,
dot_shuffle_tmp
);
// store the prefetch
}
move_tile_window
(
dot_dram_window
,
{
0
,
kK1
});
if
constexpr
(
kHasDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
dropout
.
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>
(
seqlen_q_st
art
+
i_total_loops
*
kM0
,
p
t
,
randval_dram_window
);
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_st
ep
,
k_origin
.
at
(
number
<
0
>
{})
,
p
,
randval_dram_window
);
}
// STAGE 3, P^T@OGrad^T Gemm1
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
kHasDropout
)
const
auto
p_gemm
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
p
t
);
p
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
p
t
);
return
cast_tile
<
GemmDataType
>
(
p
);
}
}();
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
dot
=
load_tile
(
dot_dram_window
);
// load next OGrad^T
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
i_k1
*
kK1
,
0
>
{},
sequence
<
(
i_k1
+
1
)
*
kK1
,
kN0
>
{}),
dot_lds_window
);
block_sync_lds
();
shuffle_tile
(
dot_shuffle_tmp
,
dot
);
store_tile
(
dot_lds_window
,
dot_shuffle_tmp
);
// store the prefetch
move_tile_window
(
dot_dram_window
,
{
0
,
kK1
});
});
}
auto
do_block_tile
=
load_tile
(
do_dram_window
);
// prefetch load OGrad tile
// tail
{
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
(
k1_loops
-
1
)
*
kK1
,
0
>
{},
sequence
<
kM0
,
kN0
>
{}),
dot_lds_window
);
block_sync_lds
();
}
// STAGE 3, P^T@OGrad^T Gemm1
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
p_gemm
)>(
pt_reg_tensor
,
p_gemm
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
{
move_tile_window
(
do_dram_window
,
{
0
,
kK2
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 4, OGrad@V Gemm2
auto
dp_acc
=
SPGradBlockTileType
{};
clear_tile
(
dp
t
_acc
);
// Initialize PGrad^T
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
store_tile
(
do_lds_window
,
do_block_tile
);
// LDS write 0
do_block_tile
=
load_tile
(
do_dram_window
);
// global read 1
}
block_sync_lds
();
if
constexpr
(
k2_loops
>
2
)
{
static_for
<
0
,
k2_loops
-
2
,
1
>
{}([
&
](
auto
i_k2
)
{
block_sync_lds
();
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kN0
,
(
i_k2
+
1
)
*
kK2
>
{}));
block_sync_lds
();
move_tile_window
(
do_dram_window
,
{
0
,
kK2
});
store_tile
(
do_lds_window
,
do_block_tile
);
// LDS write i + 1
do_block_tile
=
load_tile
(
do_dram_window
);
// global read i + 2
});
}
store_tile
(
q_lds_window
,
q_block_tile
);
shuffle_tile
(
shuffled_q_block_tile
,
q_block_tile
);
store_tile
(
shuffled_q_lds_write_window
,
shuffled_q_block_tile
);
const
auto
qt_prefetch
=
load_tile
(
qt_dram_window
);
// prefetch load Q^T tile
{
// tail
block_sync_lds
();
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
(
k2_loops
-
2
)
*
kK2
>
{},
sequence
<
kN0
,
(
k2_loops
-
1
)
*
kK2
>
{}));
block_sync_lds
();
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
store_tile
(
do_lds_window
,
do_block_tile
);
block_sync_lds
();
store_tile
(
do_lds_window
,
do_block_tile
);
shuffle_tile
(
shuffled_do_block_tile
,
do_block_tile
);
store_tile
(
shuffled_do_lds_write_window
,
shuffled_do_block_tile
);
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
(
k2_loops
-
1
)
*
kK2
>
{},
sequence
<
kN0
,
k2_loops
*
kK2
>
{}));
}
store_tile
(
d_lds_write_window
,
d_block_tile
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 5, P^T(PGrad^T - D)
const
auto
d
=
load_tile
(
d_dram_window
);
auto
dst
=
SPGradTBlockTileType
{};
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
sweep_tile_span
(
dst_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
auto
ds
=
SPGradBlockTileType
{};
constexpr
auto
ds_spans
=
decltype
(
ds
)
::
get_distributed_spans
();
sweep_tile_span
(
ds_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
ds
t
_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
sweep_tile_span
(
ds_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
p
t
[
i_j_idx
]
>=
0
;
ds
t
(
i_j_idx
)
=
pt
[
i_j_idx
]
*
(
!
kHasDropout
||
undrop_flag
?
(
dpt_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
bool
undrop_flag
=
p
[
i_j_idx
]
>=
0
;
ds
(
i_j_idx
)
=
p
[
i_j_idx
]
*
(
!
FmhaDropout
::
IsDropout
||
undrop_flag
?
(
dp_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbias
t
=
[
&
]()
{
if
constexpr
(
kHa
sDropout
)
const
auto
dbias
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
ds
t
);
ds
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
ds
t
);
return
cast_tile
<
BiasGradDataType
>
(
ds
);
}
}();
store_tile
(
bias
t
_lds_
shuffl
e_window
,
dbias
t
);
store_tile
(
bias_lds_
writ
e_window
,
dbias
);
block_sync_lds
();
auto
dbias
t
_tile
=
load_tile
(
dbias
t
_lds_
shuffle
_window
);
auto
dbias
t_shuffle_tmp
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
auto
shuffled_
dbias_tile
=
load_tile
(
dbias_lds_
read
_window
);
auto
dbias
_tile
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
store_tile
(
dbias_dram_block_window
,
dbiast_shuffle_tmp
);
move_tile_window
(
dbias_dram_block_window
,
{
kM0
,
0
});
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
move_tile_window
(
dbias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
auto
qt_shuffle_tmp
=
make_static_distributed_tensor
<
QDataType
>
(
Policy
::
template
MakeShuffledQTRegBlockDescriptor
<
Problem
>());
const
auto
ds_gemm
=
cast_tile
<
GemmDataType
>
(
ds
);
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
ds_gemm
)>(
dst_reg_tensor
,
ds_gemm
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
store_tile
(
ds_lds_window
,
ds_gemm
);
block_sync_lds
();
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
lse
=
load_tile
(
lse_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE7 SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor_next
=
load_tile
(
ds_lds_read_window
);
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
}
auto
kt_reg_tensor_slice
=
get_slice_tile
(
kt_reg_tensor
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{});
gemm_4
(
dq_acc
,
ds_reg_tensor
,
kt_reg_tensor_slice
);
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor
.
get_thread_buffer
()
=
ds_reg_tensor_next
.
get_thread_buffer
();
}
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
-
kN0
});
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
d
=
load_tile
(
d_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
// QGrad Scale
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
}
if
constexpr
(
kIsDeterministic
)
{
shuffle_tile
(
qt_shuffle_tmp
,
qt_prefetch
);
store_tile
(
qt_lds_window
,
qt_shuffle_tmp
);
// store the prefetch
store_tile
(
dq_dram_window
,
dq_acc
);
}
move_tile_window
(
qt_dram_window
,
{
0
,
kK3
});
else
{
update_tile
(
dq_dram_window
,
dq_acc
);
}
move_tile_window
(
dq_dram_window
,
{
kM0
,
0
});
i_total_loops
+=
1
;
seqlen_q_step
+=
kM0
;
}
__builtin_amdgcn_sched_barrier
(
0
);
// Tail
auto
s_acc
=
SPBlockTileType
{};
// STAGE 1, Q@K Gemm0
s_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
auto
shuffled_bias_tile
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
shuffled_bias_tile
,
bias_tile
);
store_tile
(
bias_lds_write_window
,
shuffled_bias_tile
);
block_sync_lds
();
auto
bias_s_tile
=
load_tile
(
bias_s_lds_read_window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
},
s_acc
,
bias_s_tile
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
if
constexpr
(
k3_loops
>
1
)
s_acc
(
i_j_idx
)
*=
scale
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
static_for
<
0
,
k3_loops
-
1
,
1
>
{}([
&
](
auto
i_k3
)
{
const
auto
qt
=
load_tile
(
qt_dram_window
);
// load next Q^T
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
i_k3
*
kK3
,
0
>
{},
sequence
<
(
i_k3
+
1
)
*
kK3
,
kN0
>
{}),
qt_lds_window
);
block_sync_lds
();
shuffle_tile
(
qt_shuffle_tmp
,
qt
);
store_tile
(
qt_lds_window
,
qt_shuffle_tmp
);
// store the prefetch
move_tile_window
(
qt_dram_window
,
{
0
,
kK3
});
set_tile_if
(
s_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
// tail
}
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
(
k3_loops
-
1
)
*
kK3
,
0
>
{},
sequence
<
kM0
,
kN0
>
{}),
qt_lds_window
);
block_sync_lds
();
return
raw_lse
==
-
numeric
<
LSEDataType
>::
infinity
()
?
type_convert
<
LSEDataType
>
(
0.
f
)
:
raw_lse
;
}
else
{
return
raw_lse
;
}
};
// STAGE 7, SGrad@K^T Gemm4
store_tile
(
ds_lds_window
,
dst_gemm
);
auto
p
=
SPBlockTileType
{};
constexpr
auto
p_spans
=
decltype
(
p
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
// Initialize QGrad
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p
(
i_j_idx
)
=
exp2
(
s_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
p
(
i_j_idx
)
=
exp2
(
scale
*
s_acc
[
i_j_idx
]
-
row_lse
);
}
});
});
block_sync_lds
();
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
p
,
randval_dram_window
);
}
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
gemm_4
(
dq_acc
,
get_slice_tile
(
ds_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kM0
,
(
i_k4
+
1
)
*
kK4
>
{}),
get_slice_tile
(
kt_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{}));
// STAGE 3, P^T@OGrad^T Gemm1
const
auto
p_gemm
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
p
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
p
);
}
}();
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
p_gemm
)>(
pt_reg_tensor
,
p_gemm
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
// STAGE 4, OGrad@V Gemm2
auto
dp_acc
=
SPGradBlockTileType
{};
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
// STAGE 5, P^T(PGrad^T - D)
auto
ds
=
SPGradBlockTileType
{};
constexpr
auto
ds_spans
=
decltype
(
ds
)
::
get_distributed_spans
();
sweep_tile_span
(
ds_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
ds_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
p
[
i_j_idx
]
>=
0
;
ds
(
i_j_idx
)
=
p
[
i_j_idx
]
*
(
!
FmhaDropout
::
IsDropout
||
undrop_flag
?
(
dp_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
// QGrad Scale
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbias
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
ds
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
ds
);
}
}();
store_tile
(
bias_lds_write_window
,
dbias
);
block_sync_lds
();
auto
shuffled_dbias_tile
=
load_tile
(
dbias_lds_read_window
);
auto
dbias_tile
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
const
auto
ds_gemm
=
cast_tile
<
GemmDataType
>
(
ds
);
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
ds_gemm
)>(
dst_reg_tensor
,
ds_gemm
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
store_tile
(
ds_lds_window
,
ds_gemm
);
block_sync_lds
();
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
// STAGE 7, SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
ds_reg_tensor_next
=
load_tile
(
ds_lds_read_window
);
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
}
);
}
else
auto
kt_reg_tensor_slice
=
get_slice_tile
(
kt_reg_tensor
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{});
gemm_4
(
dq_acc
,
ds_reg_tensor
,
kt_reg_tensor_slice
);
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
ds_reg_tensor
.
get_thread_buffer
()
=
ds_reg_tensor_next
.
get_thread_buffer
(
);
}
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
update_tile
(
dq_dram_block_window
,
dq
);
});
// move tile windows
move_tile_window
(
q_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
dq_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
do_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
lse_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
}
while
(
++
i_total_loops
<
num_total_loop
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
//
KGrad
Scale
if
constexpr
(
kHa
sDropout
)
//
Results
Scale
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dk_acc
);
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
}
// VGrad Scale
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsDeterministic
)
{
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
store_tile
(
dq_dram_window
,
dq_acc
);
}
else
{
update_tile
(
dq_dram_window
,
dq_acc
);
}
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
deleted
100644 → 0
View file @
c160c6cf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// Default policy for the "ks_kts_vr" FMHA backward pipeline variant:
// This pipeline is v located in regs, k & k^t located in lds.
// Each flag tells the shared default policy which operand is loaded once and
// kept resident (true) versus re-loaded per K-tile iteration (false).
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ true,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
deleted
100644 → 0
View file @
c160c6cf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

// FMHA backward dQ/dK/dV pipeline, "ks_vr" variant: K stays resident in LDS
// and V stays resident in registers for the whole sweep over query (M0) tiles.
// Per iteration it runs five block GEMMs (S = Q@K^T, dV += P^T@dO^T,
// dP = dO@V, dK += dS^T@Q^T, dQ = dS@K^T), writes/accumulates dQ (and
// optionally dBias) to DRAM, and finally returns the dK/dV accumulators.
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSVR
{
    using QDataType             = remove_cvref_t<typename Problem::QDataType>;
    using KDataType             = remove_cvref_t<typename Problem::KDataType>;
    using VDataType             = remove_cvref_t<typename Problem::VDataType>;
    using GemmDataType          = remove_cvref_t<typename Problem::GemmDataType>;
    using BiasDataType          = remove_cvref_t<typename Problem::BiasDataType>;
    using LSEDataType           = remove_cvref_t<typename Problem::LSEDataType>;
    using AccDataType           = remove_cvref_t<typename Problem::AccDataType>;
    using DDataType             = remove_cvref_t<typename Problem::DDataType>;
    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
    // NOTE(review): ODataType is not referenced in this pipeline; presumably
    // kept for interface symmetry with sibling pipelines — confirm.
    using ODataType        = remove_cvref_t<typename Problem::ODataType>;
    using OGradDataType    = remove_cvref_t<typename Problem::OGradDataType>;
    using QGradDataType    = remove_cvref_t<typename Problem::QGradDataType>;
    using KGradDataType    = remove_cvref_t<typename Problem::KGradDataType>;
    using VGradDataType    = remove_cvref_t<typename Problem::VGradDataType>;
    using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
    using FmhaMask         = remove_cvref_t<typename Problem::FmhaMask>;
    using BlockFmhaShape   = remove_cvref_t<typename Problem::BlockFmhaShape>;

    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;

    // Block tile extents (M0/N0 tile shape, kK* per-GEMM unroll step sizes).
    static constexpr index_t kM0        = BlockFmhaShape::kM0;
    static constexpr index_t kN0        = BlockFmhaShape::kN0;
    static constexpr index_t kK0        = BlockFmhaShape::kK0;
    static constexpr index_t kK1        = BlockFmhaShape::kK1;
    static constexpr index_t kK2        = BlockFmhaShape::kK2;
    static constexpr index_t kK3        = BlockFmhaShape::kK3;
    static constexpr index_t kK4        = BlockFmhaShape::kK4;
    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kVHeaddim  = BlockFmhaShape::kVHeaddim;

    // Residency flags of this variant: only K (LDS) and V (registers) are
    // loaded once; everything else streams per M0-tile iteration.
    static constexpr bool kQLoadOnce      = false;
    static constexpr bool kQTLoadOnce     = false;
    static constexpr bool kKLoadOnce      = true;
    static constexpr bool kKTLoadOnce     = false;
    static constexpr bool kVLoadOnce      = true;
    static constexpr bool kOGradLoadOnce  = false;
    static constexpr bool kOGradTLoadOnce = false;

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
    static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
    static constexpr bool kHasDropout  = Problem::kHasDropout;

    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
    // ... together with tensor distribution. tensor dist should able to overwrite this
    static constexpr index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
    // NOTE(review): QGrad falls back to 2 (not 1) when padded — presumably to
    // keep atomic/packed dQ stores legal; confirm against the policy.
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
    static constexpr index_t kAlignmentKGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
    static constexpr index_t kAlignmentVGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();

    // Short identifier of this pipeline variant (used for kernel naming).
    static constexpr const char* name = "ks_vr";

    // Total LDS bytes this pipeline needs; delegated to the policy.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    // Runs the backward pass for one K/V (N0) tile against all relevant Q (M0)
    // tiles. Windows are DRAM block windows for the respective operands; the
    // KT window is unused here (K^T is read from the K LDS tile instead).
    // Returns make_tuple(dk_acc, dv_acc) for the caller to scale/store.
    template <typename QDramBlockWindowTmp,
              typename QTDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename KTDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename OGradTDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename DDramBlockWindowTmp,
              typename QGradDramBlockWindowTmp,
              typename BiasGradDramBlockWindowTmp,
              typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
               const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
               const KDramBlockWindowTmp& k_dram_block_window_tmp,
               const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
               const VDramBlockWindowTmp& v_dram_block_window_tmp,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
               const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
               const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
               const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
               const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
               const DDramBlockWindowTmp& d_dram_block_window_tmp,
               const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
               const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
               FmhaMask mask,
               PositionEncoding position_encoding,
               float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
               float scale, // raw_scale premultiplied by log2(e) for exp2-based softmax
#endif
               float rp_undrop,       // 1 / (1 - p_drop)
               float scale_rp_undrop, // raw_scale * rp_undrop
               void* smem_ptr,
               BlockDropout& dropout) const
    {
        // Window data types must match the Problem's declared types.
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QDataType,
                               remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType,
                               remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType,
                               remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<LSEDataType,
                               remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
                std::is_same_v<DDataType,
                               remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");

        // Window extents must match the block tile shape.
        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kVHeaddim ==
                              OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // LDS layout: the persistent K tile occupies the first
        // GetSmemSizeK<Problem>() bytes; all streamed tiles below (Q, Q^T,
        // dO, dO^T, dS, bias) alias the region right after it.
        // NOTE(review): this aliasing relies on the stage ordering plus the
        // block_sync_lds() fences below — confirm against Policy::GetSmemSize.

        // Q tile in LDS
        QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        auto q_lds_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});

        // QT tile in LDS
        QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto qt_lds = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
        auto qt_lds_window =
            make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});

        // K tile in LDS
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<KDataType*>(smem_ptr),
            Policy::template MakeKLdsBlockDescriptor<Problem>());
        auto k_lds_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});

        // KT tile in LDS: a transposed view over the same K storage.
        auto kt_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<KDataType*>(smem_ptr),
            Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
        auto kt_lds_window =
            make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});

        // OGrad tile in LDS
        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto do_lds = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
        auto do_lds_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});

        // OGradT tile in LDS
        OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto dot_lds = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
        auto dot_lds_window =
            make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});

        // SGrad tile in LDS
        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto ds_lds = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        // BiasT/BiasGradT tile in LDS, use the same size and layout
        BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto biast_lds = make_tensor_view<address_space_enum::lds>(
            biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
        auto biast_lds_shuffle_window =
            make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto dbiast_lds_shuffle_window =
            make_tile_window(biast_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());
        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");

        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();       // S^T  = Q@K^T
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>(); // dV  += P^T@dO^T
        constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();   // dP^T = dO@V
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>(); // dK  += dS^T@Q^T
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();  // dQ   = dS@K^T

        auto v_dram_window =
            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
                             v_dram_block_window_tmp.get_window_lengths(),
                             v_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeVInRegDramTileDistribution<Problem,
                                                                             decltype(gemm_2)>());

        auto v = load_tile(v_dram_window); // persistent V register tile

        using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
        using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
        using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());

        // init VGrad & KGrad
        auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
        auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};

        clear_tile(dv_acc);
        clear_tile(dk_acc);

        auto k_dram_window = make_tile_window(
            k_dram_block_window_tmp.get_bottom_tensor_view(),
            k_dram_block_window_tmp.get_window_lengths(),
            k_dram_block_window_tmp.get_window_origin(),
            Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load

        __builtin_amdgcn_sched_barrier(0);
        const auto k_origin = k_dram_window.get_window_origin();
        // Range of query rows this K tile can interact with under the mask.
        const auto [seqlen_q_start, seqlen_q_end] =
            mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

        const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);

        // check early exit if masked and no work to do.
        if constexpr(FmhaMask::IsMasking)
        {
            if(num_total_loop <= 0)
            {
                // Note: here dk_acc&dv_acc are all cleared, return it
                // Note: v loaded but no fence, ignore it.
                return ck_tile::make_tuple(dk_acc, dv_acc);
            }
        }

        auto k_block_tile = load_tile(k_dram_window);
        store_tile(k_lds_window, k_block_tile); // persistent K in LDS

        // Streaming DRAM windows, all positioned at seqlen_q_start.
        auto q_dram_block_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                                    q_dram_block_window_tmp.get_window_lengths(),
                                                    {seqlen_q_start, 0});
        auto qt_dram_block_window =
            make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
                             qt_dram_block_window_tmp.get_window_lengths(),
                             {0, seqlen_q_start});
        auto do_dram_block_window =
            make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                             do_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});
        auto dot_dram_block_window =
            make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
                             dot_dram_block_window_tmp.get_window_lengths(),
                             {0, seqlen_q_start});
        auto dq_dram_block_window =
            make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});
        auto lse_dram_block_window =
            make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
                             lse_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start});
        auto d_dram_block_window =
            make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
                             d_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start});

        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_block_window =
            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                             bias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
        const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
        auto dbias_dram_block_window =
            make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                             dbias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N

        auto qt_dram_window =
            make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
                             qt_dram_block_window.get_window_lengths(),
                             qt_dram_block_window.get_window_origin(),
                             Policy::template MakeQTDramTileDistribution<Problem>());
        auto dot_dram_window =
            make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
                             dot_dram_block_window.get_window_lengths(),
                             dot_dram_block_window.get_window_origin(),
                             Policy::template MakeOGradTDramTileDistribution<Problem>());
        auto lse_dram_window =
            make_tile_window(lse_dram_block_window.get_bottom_tensor_view(),
                             lse_dram_block_window.get_window_lengths(),
                             lse_dram_block_window.get_window_origin(),
                             Policy::template MakeLSEDDramTileDistribution<Problem,
                                                                           decltype(gemm_0)>());
        auto d_dram_window =
            make_tile_window(d_dram_block_window.get_bottom_tensor_view(),
                             d_dram_block_window.get_window_lengths(),
                             d_dram_block_window.get_window_origin(),
                             Policy::template MakeLSEDDramTileDistribution<Problem,
                                                                           decltype(gemm_0)>());
        auto bias_dram_window =
            make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
                             bias_dram_block_window.get_window_lengths(),
                             bias_dram_block_window.get_window_origin(),
                             Policy::template MakeBiasTileDistribution<Problem>());
        auto biast_lds_window =
            make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
                             biast_lds_shuffle_window.get_window_lengths(),
                             biast_lds_shuffle_window.get_window_origin(),
                             Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
        auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
            randval_dram_block_window_tmp, seqlen_q_start);

        index_t i_total_loops = 0;
        // Per-GEMM K-dimension step counts.
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kM0 / kK1;
        constexpr index_t k2_loops = kVHeaddim / kK2;
        constexpr index_t k3_loops = kM0 / kK3;
        constexpr index_t k4_loops = kN0 / kK4;

        // Main loop over M0-sized query tiles.
        do
        {
            auto q_dram_window = make_tile_window(
                q_dram_block_window.get_bottom_tensor_view(),
                q_dram_block_window.get_window_lengths(),
                q_dram_block_window.get_window_origin(),
                Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
                                                                       // load

            auto do_dram_window = make_tile_window(
                do_dram_block_window.get_bottom_tensor_view(),
                do_dram_block_window.get_window_lengths(),
                do_dram_block_window.get_window_origin(),
                Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
                                                                           // window for load

            // STAGE 1, Q@K Gemm0
            auto st_acc = SPTBlockTileType{};

            auto q_block_tile = load_tile(q_dram_window);
            {
                move_tile_window(q_dram_window, {0, kK0});
                clear_tile(st_acc); // Initialize S^T

                store_tile(q_lds_window, q_block_tile);   // LDS write 0
                q_block_tile = load_tile(q_dram_window);  // global read 1
            }

            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(
                    0); // prevent from messing up the order of global loads
            }
            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(
                    0); // prevent from messing up the order of global loads
            }

            // Software-pipelined Q@K^T: overlap Q LDS writes / global reads
            // with the per-kK0-slice GEMM.
            if constexpr(k0_loops > 2)
            {
                static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
                    block_sync_lds();
                    gemm_0(st_acc,
                           q_lds_window,
                           get_slice_tile(k_lds_window,
                                          sequence<0, i_k0 * kK0>{},
                                          sequence<kN0, (i_k0 + 1) * kK0>{}));
                    block_sync_lds();
                    move_tile_window(q_dram_window, {0, kK0});

                    store_tile(q_lds_window, q_block_tile);  // LDS write i + 1
                    q_block_tile = load_tile(q_dram_window); // global read i + 2
                });
            }

            const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile

            { // tail
                block_sync_lds();
                gemm_0(st_acc,
                       q_lds_window,
                       get_slice_tile(k_lds_window,
                                      sequence<0, (k0_loops - 2) * kK0>{},
                                      sequence<kN0, (k0_loops - 1) * kK0>{}));
                block_sync_lds();

                store_tile(q_lds_window, q_block_tile);
                block_sync_lds();

                gemm_0(st_acc,
                       q_lds_window,
                       get_slice_tile(k_lds_window,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }

            // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                // Shuffle the bias tile through LDS into the GEMM-0 output
                // distribution before adding it to S^T.
                block_sync_lds();
                auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
                    Policy::template MakeShuffledBiasTileDistribution<Problem>());
                shuffle_tile(bias_shuffle_tmp, bias_tile);
                store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
                block_sync_lds();
                auto biast_tile = load_tile(biast_lds_window);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x = raw_scale * x + type_convert<AccDataType>(y);
#else
                        x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
                    },
                    st_acc,
                    biast_tile);
                move_tile_window(bias_dram_window, {kM0, 0});
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                const auto q_origin    = q_dram_block_window.get_window_origin();
                constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
                sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        st_acc(i_j_idx) *= raw_scale;
#else
                        st_acc(i_j_idx) *= scale;
#endif
                        position_encoding.update(st_acc(i_j_idx), row, col);
                    });
                });
            }
            else
            {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
            }

            // Mask out-of-bound / masked elements with -inf before softmax.
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                const auto q_origin = q_dram_block_window.get_window_origin();
                bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                           k_origin.at(number<0>{}),
                                                           number<kM0>{},
                                                           number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        return mask.IsOutOfBound(row, col);
                    });
                }
            }

            const auto lse = load_tile(lse_dram_window);

            // A fully-masked row stores LSE = -inf; treat it as 0 so the
            // exp() below produces 0 instead of NaN.
            static const auto get_validated_lse = [](LSEDataType raw_lse) {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
                {
                    return raw_lse == -numeric<LSEDataType>::infinity()
                               ? type_convert<LSEDataType>(0.f)
                               : raw_lse;
                }
                else
                {
                    return raw_lse;
                }
            };

            // Recompute P^T = softmax(S^T) from S^T and the stored LSE.
            auto pt = SPTBlockTileType{};
            constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
            sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
                sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
                    }
                    else
                    {
                        pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
                    }
#else
                    pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
                });
            });

            auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
                Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
            block_sync_lds();
            {
                shuffle_tile(dot_shuffle_tmp, dot_prefetch);
                store_tile(dot_lds_window, dot_shuffle_tmp); // store the prefetch
            }
            move_tile_window(dot_dram_window, {0, kK1});

            if constexpr(kHasDropout)
            {
                // Applies the dropout pattern in place on pt (and optionally
                // writes random values out for validation).
                dropout.Run<decltype(gemm_0), RandValOutputDataType>(
                    seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
            }

            // STAGE 3, P^T@OGrad^T Gemm1
            const auto pt_gemm = [&]() {
                if constexpr(kHasDropout)
                {
                    // Dropped elements were negated by dropout.Run; clamp them
                    // to 0 for the dV accumulation.
                    return tile_elementwise_in(
                        [](const auto& x) {
                            return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
                        },
                        pt);
                }
                else
                {
                    return cast_tile<GemmDataType>(pt);
                }
            }();

            if constexpr(k1_loops > 1)
            {
                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                    const auto dot = load_tile(dot_dram_window); // load next OGrad^T
                    block_sync_lds();
                    gemm_1(dv_acc,
                           get_slice_tile(pt_gemm,
                                          sequence<i_k1 * kK1, 0>{},
                                          sequence<(i_k1 + 1) * kK1, kN0>{}),
                           dot_lds_window);
                    block_sync_lds();
                    shuffle_tile(dot_shuffle_tmp, dot);
                    store_tile(dot_lds_window, dot_shuffle_tmp); // store the prefetch
                    move_tile_window(dot_dram_window, {0, kK1});
                });
            }
            auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
            // tail
            {
                block_sync_lds();
                gemm_1(dv_acc,
                       get_slice_tile(pt_gemm,
                                      sequence<(k1_loops - 1) * kK1, 0>{},
                                      sequence<kM0, kN0>{}),
                       dot_lds_window);
                block_sync_lds();
            }

            // STAGE 4, OGrad@V Gemm2
            auto dpt_acc = SPGradTBlockTileType{};
            {
                move_tile_window(do_dram_window, {0, kK2});
                clear_tile(dpt_acc); // Initialize PGrad^T

                store_tile(do_lds_window, do_block_tile);   // LDS write 0
                do_block_tile = load_tile(do_dram_window);  // global read 1
            }

            if constexpr(k2_loops > 2)
            {
                static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
                    block_sync_lds();
                    gemm_2(dpt_acc,
                           do_lds_window,
                           get_slice_tile(v,
                                          sequence<0, i_k2 * kK2>{},
                                          sequence<kN0, (i_k2 + 1) * kK2>{}));
                    block_sync_lds();
                    move_tile_window(do_dram_window, {0, kK2});

                    store_tile(do_lds_window, do_block_tile);  // LDS write i + 1
                    do_block_tile = load_tile(do_dram_window); // global read i + 2
                });
            }

            const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile

            { // tail
                block_sync_lds();
                gemm_2(dpt_acc,
                       do_lds_window,
                       get_slice_tile(v,
                                      sequence<0, (k2_loops - 2) * kK2>{},
                                      sequence<kN0, (k2_loops - 1) * kK2>{}));
                block_sync_lds();

                store_tile(do_lds_window, do_block_tile);
                block_sync_lds();

                gemm_2(dpt_acc,
                       do_lds_window,
                       get_slice_tile(v,
                                      sequence<0, (k2_loops - 1) * kK2>{},
                                      sequence<kN0, k2_loops * kK2>{}));
            }

            // STAGE 5, P^T(PGrad^T - D)
            const auto d = load_tile(d_dram_window);
            auto dst = SPGradTBlockTileType{};
            constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
            sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    // With dropout, kept elements stay positive; dropped ones
                    // carry a negated pt, so select the branch accordingly.
                    bool undrop_flag = pt[i_j_idx] >= 0;
                    dst(i_j_idx) =
                        pt[i_j_idx] *
                        (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
                });
            });

            if constexpr(kHasBiasGrad)
            {
                const auto dbiast = [&]() {
                    if constexpr(kHasDropout)
                    {
                        return tile_elementwise_in(
                            [&rp_undrop](const auto& x) {
                                return type_convert<BiasGradDataType>(x * rp_undrop);
                            },
                            dst);
                    }
                    else
                    {
                        return cast_tile<BiasGradDataType>(dst);
                    }
                }();
                // Round-trip through LDS to shuffle dBias^T into the row-major
                // DRAM layout before the store.
                store_tile(biast_lds_shuffle_window, dbiast);
                block_sync_lds();
                auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
                auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
                    Policy::template MakeBiasTileDistribution<Problem>());
                shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
                store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
                move_tile_window(dbias_dram_block_window, {kM0, 0});
            }

            // STAGE 6, SGrad^T@Q^T Gemm3
            auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
                Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
            block_sync_lds();
            {
                shuffle_tile(qt_shuffle_tmp, qt_prefetch);
                store_tile(qt_lds_window, qt_shuffle_tmp); // store the prefetch
            }
            move_tile_window(qt_dram_window, {0, kK3});

            const auto dst_gemm = cast_tile<GemmDataType>(dst);

            if constexpr(k3_loops > 1)
            {
                static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
                    const auto qt = load_tile(qt_dram_window); // load next Q^T
                    block_sync_lds();
                    gemm_3(dk_acc,
                           get_slice_tile(dst_gemm,
                                          sequence<i_k3 * kK3, 0>{},
                                          sequence<(i_k3 + 1) * kK3, kN0>{}),
                           qt_lds_window);
                    block_sync_lds();
                    shuffle_tile(qt_shuffle_tmp, qt);
                    store_tile(qt_lds_window, qt_shuffle_tmp); // store the prefetch
                    move_tile_window(qt_dram_window, {0, kK3});
                });
            }
            // tail
            {
                block_sync_lds();
                gemm_3(dk_acc,
                       get_slice_tile(dst_gemm,
                                      sequence<(k3_loops - 1) * kK3, 0>{},
                                      sequence<kM0, kN0>{}),
                       qt_lds_window);
                block_sync_lds();
            }

            // STAGE 7, SGrad@K^T Gemm4
            store_tile(ds_lds_window, dst_gemm);

            auto dq_acc = QGradBlockTileType{};
            clear_tile(dq_acc); // Initialize QGrad
            block_sync_lds();

            static_for<0, k4_loops, 1>{}([&](auto i_k4) {
                gemm_4(dq_acc,
                       get_slice_tile(ds_lds_window,
                                      sequence<0, i_k4 * kK4>{},
                                      sequence<kM0, (i_k4 + 1) * kK4>{}),
                       get_slice_tile(kt_lds_window,
                                      sequence<0, i_k4 * kK4>{},
                                      sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
            });

            // QGrad Scale
            if constexpr(kHasDropout)
            {
                tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                       dq_acc);
            }
            else
            {
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
            }

            const auto dq = cast_tile<QGradDataType>(dq_acc);
            // Accumulate (not overwrite): multiple K-tile workgroups contribute
            // to the same dQ rows.
            update_tile(dq_dram_block_window, dq);

            // move tile windows
            move_tile_window(q_dram_block_window, {kM0, 0});
            move_tile_window(dq_dram_block_window, {kM0, 0});
            move_tile_window(do_dram_block_window, {kM0, 0});
            move_tile_window(lse_dram_window, {kM0});
            move_tile_window(d_dram_window, {kM0});
        } while(++i_total_loops < num_total_loop);

        // KGrad Scale
        if constexpr(kHasDropout)
        {
            tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                   dk_acc);
        }
        else
        {
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
        }

        // VGrad Scale
        if constexpr(kHasDropout)
        {
            tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
        }

        return ck_tile::make_tuple(dk_acc, dv_acc);
    }
};

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
deleted
100644 → 0
View file @
c160c6cf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// Default policy for the "ks_vr" FMHA backward pipeline variant:
// This pipeline is v located in regs, k located in lds.
// Each flag tells the shared default policy which operand is loaded once and
// kept resident (true) versus re-loaded per tile iteration (false).
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
deleted
100644 → 0
View file @
c160c6cf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// Default policy for the "qs_ks_vr_dos" FMHA backward pipeline variant:
// This pipeline is v located in regs, q & k & do located in lds.
// Each flag tells the shared default policy which operand is loaded once and
// kept resident (true) versus re-loaded per tile iteration (false).
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ true,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ true,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
Prev
1
…
4
5
6
7
8
9
10
11
12
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment