Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f6ceef78
Commit
f6ceef78
authored
Aug 26, 2024
by
ThomasNing
Browse files
merge with the develop branch
parents
536c5458
25935b57
Changes
240
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3815 additions
and
1851 deletions
+3815
-1851
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
...e/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
+141
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp
...ude/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp
+3
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+782
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+1037
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
...k_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
.../fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
+0
-821
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+1514
-922
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
...k_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
+2
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
...ile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
+37
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+3
-1
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+10
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+3
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+202
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
...gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
+36
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
...emm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
+33
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
+6
-9
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
+6
-9
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
0 → 100644
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

/// Post-processing stage of FMHA backward: converts the high-precision dQ
/// accumulator (AccDataType) tile read from DRAM into the output QGradDataType
/// and stores it. A second overload additionally reduces over `nsplits`
/// partial-dQ slices (stacked along dim 0 of a 3-D accumulator window) before
/// converting. Uses no shared memory (GetSmemSize() == 0).
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
    // Data types pulled from the Problem descriptor: accumulator (input) and
    // converted dQ (output) element types.
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;

    // Block-tile geometry and launch parameters, forwarded from Problem.
    static constexpr index_t kM0         = Problem::kM0;
    static constexpr index_t kN0         = Problem::kN0;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kQKHeaddim  = Problem::kQKHeaddim;

    // Mode/padding flags, forwarded from Problem.
    static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    // Vector-access alignment for the dQ-accumulator / dQ DRAM windows.
    // When the head dim is padded, fall back to scalar (alignment 1) access.
    static constexpr index_t kAlignmentQGradAcc =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();

    /// This stage needs no LDS/shared memory.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    // Convert only
    /// Loads one dQ-accumulator tile, casts it to QGradDataType, and stores it
    /// through the dQ DRAM window. No reduction is performed.
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
    {
        // Window element types must match the Problem's acc / dQ types.
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        // The dQ window's first dimension must equal the kM0 block tile.
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                      "wrong!");

        // Rebuild the accumulator window with the policy's post-dQ tile
        // distribution so each thread loads its assigned elements.
        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradDramTileDistribution<Problem>());

        auto dq_acc = load_tile(dq_acc_dram_window);
        // Element-wise cast AccDataType -> QGradDataType.
        const auto dq = cast_tile<QGradDataType>(dq_acc);
        store_tile(dq_dram_block_window_tmp, dq);
    }

    // Reduce + Convert
    /// Sums `nsplits` partial dQ-accumulator tiles (the accumulator window is
    /// 3-D; the window advances by {1, 0, 0}, i.e. one split per step), then
    /// converts the sum to QGradDataType and stores it.
    ///
    /// NOTE(review): the do-while body always executes once and is followed by
    /// one more tail accumulation, so this overload appears to assume
    /// nsplits >= 2 (nsplits == 1 would read one tile past the window and
    /// accumulate it). The convert-only overload presumably covers the
    /// single-split case — TODO confirm with callers.
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
               index_t nsplits) const
    {
        // Window element types must match the Problem's acc / dQ types.
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        // The dQ window's first dimension must equal the kM0 block tile.
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                      "wrong!");

        // Accumulator window with the (3-D) post-dQ-acc tile distribution.
        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());

        // Zero-initialized running sum, with the same distribution as a load.
        auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
        clear_tile(dq_acc);

        // Distributed spans (3 levels) used to iterate per-thread elements.
        constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();

        index_t i_total_loops = 0;

        // Software-pipelined reduction: pre-load the first split, then in each
        // iteration accumulate the buffered tile while loading the next.
        auto dq_acc_buf = load_tile(dq_acc_dram_window);
        move_tile_window(dq_acc_dram_window, {1, 0, 0}); // advance to next split
        do
        {
            // dq_acc += dq_acc_buf, element by element over the 3-D spans.
            sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                    sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                        constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                        dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                    });
                });
            });
            dq_acc_buf = load_tile(dq_acc_dram_window);
            move_tile_window(dq_acc_dram_window, {1, 0, 0});
            i_total_loops += 1;
        } while(i_total_loops < (nsplits - 1));
        // Tail: fold in the last pre-loaded split.
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                });
            });
        });

        // declare dq
        // Convert the reduced sum to QGradDataType, still laid out with the
        // accumulator's distribution.
        constexpr auto dq_converted_dstr =
            Policy::template MakePostQGradAccDramTileDistribution<Problem>();
        auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
                });
            });
        });

        // Re-wrap the converted data in the output (2-D dQ) distribution by
        // copying the underlying per-thread buffer, then store.
        constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
        auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
        dq.get_thread_buffer() = dq_converted.get_thread_buffer();
        store_tile(dq_dram_block_window_tmp, dq);
    }
};
} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp
View file @
f6ceef78
...
...
@@ -4,11 +4,11 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
dot_do_o
_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
pipeline
_default_policy.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwd
OGradDotO
DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwd
Pipeline
DefaultPolicy
>
struct
BlockFmhaBwdOGradDotO
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
...
...
@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
Grad
<
Problem
>();
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
0
;
}
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp
deleted
100644 → 0
View file @
536c5458
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// Default policy for the OGrad-dot-O (dO . O) stage of FMHA backward: an
// alias of the shared backward-pipeline default policy with every
// "load-once" knob disabled.
// These templates are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ false,
                                      /* KTLoadOnce_ = */ false,
                                      /* VLoadOnce_ = */ false,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_
qs_ks_vr_dos
.hpp
→
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_
kr_ktr_vr
.hpp
View file @
f6ceef78
...
...
@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
dq_dk_dv_pipeline_qs_ks_vr_dos
_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
pipeline
_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwd
DQDKDV
Pipeline
QSKSVROGradS
DefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipeline
QSKSVROGradS
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwdPipelineDefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipeline
KRKTRVR
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
...
...
@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
using
VGradDataType
=
remove_cvref_t
<
typename
Problem
::
VGradDataType
>
;
using
BiasGradDataType
=
remove_cvref_t
<
typename
Problem
::
BiasGradDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
HotLoopScheduler
=
typename
Policy
::
template
HotLoopScheduler
<
Problem
>;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
...
...
@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
bool
kQLoadOnce
=
true
;
static
constexpr
bool
kQTLoadOnce
=
false
;
static
constexpr
bool
kKLoadOnce
=
true
;
static
constexpr
bool
kKTLoadOnce
=
false
;
static
constexpr
bool
kVLoadOnce
=
true
;
static
constexpr
bool
kOGradLoadOnce
=
true
;
static
constexpr
bool
kOGradTLoadOnce
=
false
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kIsDeterministic
=
Problem
::
kIsDeterministic
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
kPadHeadDimQ
?
2
:
Policy
::
template
GetAlignmentQGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
1
;
static
constexpr
index_t
kAlignmentKGrad
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentKGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentVGrad
=
...
...
@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetTransposedAlignmentBias
<
Problem
>();
static
constexpr
const
char
*
name
=
"
qs_ks_vr_dos
"
;
static
constexpr
const
char
*
name
=
"
kr_ktr_vr
"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
...
...
@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
}
template
<
typename
QDramBlockWindowTmp
,
typename
QTDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KTDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
OGradDramBlockWindowTmp
,
typename
OGradTDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
DDramBlockWindowTmp
,
typename
QGradDramBlockWindowTmp
,
...
...
@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
const
QTDramBlockWindowTmp
&
/*qt_dram_block_window_tmp*/
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
const
KTDramBlockWindowTmp
&
/*kt_dram_block_window_tmp*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
const
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
const
OGradDramBlockWindowTmp
&
do_dram_block_window_tmp
,
const
OGradTDramBlockWindowTmp
&
/*dot_dram_block_window_tmp*/
,
const
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
const
DDramBlockWindowTmp
&
d_dram_block_window_tmp
,
const
QGradDramBlockWindowTmp
&
dq_dram_block_window_tmp
,
...
...
@@ -122,13 +107,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
,
#endif
float
rp_undrop
,
float
scale_rp_undrop
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
@@ -138,9 +121,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
remove_cvref_t
<
typename
OGradDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
LSEDataType
,
remove_cvref_t
<
typename
LSEDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QGradDataType
,
remove_cvref_t
<
typename
QGradDramBlockWindowTmp
::
DataType
>>
,
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
...
...
@@ -156,77 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// Q tile in LDS
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// QT tile in LDS
auto
qt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptorAsQT
<
Problem
>());
auto
qt_lds_window
=
make_tile_window
(
qt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
});
// K tile in LDS
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// KT tile in LDS
auto
kt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptorAsKT
<
Problem
>());
auto
kt_lds_window
=
make_tile_window
(
kt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// OGrad tile in LDS
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kVHeaddim
>
{}),
{
0
,
0
});
// OGradT tile in LDS
auto
dot_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptorAsOGradT
<
Problem
>());
auto
dot_lds_window
=
make_tile_window
(
dot_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
});
// SGrad tile in LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType
*
biast_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
biast_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
biast_lds_ptr
,
Policy
::
template
MakeBiasTLdsBlockDescriptor
<
Problem
>());
auto
biast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
dbiast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetPTOGradTBlockGemm
<
Problem
>();
...
...
@@ -234,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
constexpr
auto
gemm_3
=
Policy
::
template
GetSGradTQTBlockGemm
<
Problem
>();
constexpr
auto
gemm_4
=
Policy
::
template
GetSGradKTBlockGemm
<
Problem
>();
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVInRegDramTileDistribution
<
Problem
,
decltype
(
gemm_2
)>());
auto
v
=
load_tile
(
v_dram_window
);
// persistent V register tile
using
SPTBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradTBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
// init VGrad & KGrad
auto
dv_acc
=
decltype
(
gemm_1
.
MakeCBlockTile
()){};
auto
dk_acc
=
decltype
(
gemm_3
.
MakeCBlockTile
()){};
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
// K, HBM ->LDS ->Reg
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
k_origin
=
k_dram_window
.
get_window_origin
();
// Early termination
const
auto
[
seqlen_q_start
,
seqlen_q_end
]
=
mask
.
GetTileRangeAlongY
(
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
...
...
@@ -274,217 +169,408 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
// Note: here dk_acc&dv_acc are all cleared, return it
// Note: v loaded but no fence, ignore it.
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
}
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKRegSliceBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
VDataType
*
v_lds_ptr
=
static_cast
<
VDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto
shuffled_k_block_tile
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeShuffledKRegWriteBlockDescriptor
<
Problem
>());
KDataType
*
kt_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
shuffled_k_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
auto
kt_lds_read_window
=
make_tile_window
(
kt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeKTRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto
k_block_tile
=
load_tile
(
k_dram_window
);
auto
v_block_tile
=
load_tile
(
v_dram_window
);
store_tile
(
k_lds_write_window
,
k_block_tile
);
shuffle_tile
(
shuffled_k_block_tile
,
k_block_tile
);
store_tile
(
shuffled_k_lds_write_window
,
shuffled_k_block_tile
);
block_sync_lds
();
k_reg_tensor
=
load_tile
(
k_lds_read_window
);
block_sync_lds
();
auto
kt_reg_tensor
=
load_tile
(
kt_lds_read_window
);
store_tile
(
k
_lds_window
,
k
_block_tile
);
// // persistent K in LDS
store_tile
(
v
_lds_
write_
window
,
v
_block_tile
);
auto
q_dram_block_window
=
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
block_sync_lds
();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
q_lds_window
.
get_window_origin
(),
Policy
::
template
MakeQRegSliceBlockDescriptor
<
Problem
>());
auto
pt_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakePTRegSliceBlockDescriptor
<
Problem
>());
// QT: Reg -> Reg-> LDS
auto
shuffled_q_block_tile
=
make_static_distributed_tensor
<
QDataType
>
(
Policy
::
template
MakeShuffledQRegWriteBlockDescriptor
<
Problem
>());
auto
do_dram_block_window
=
QDataType
*
qt_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
shuffled_q_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
auto
qt_lds_read_window
=
make_tile_window
(
qt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeQTRegSliceBlockDescriptor
<
Problem
>());
// dO: HBM ->Reg ->LDS
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window_tmp
.
get_bottom_tensor_view
(),
do_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()));
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
do_lds_window
.
get_window_origin
(),
Policy
::
template
MakeOGradRegSliceBlockDescriptor
<
Problem
>());
// dOT: Reg ->Reg ->LDS
auto
shuffled_do_block_tile
=
make_static_distributed_tensor
<
OGradDataType
>
(
Policy
::
template
MakeShuffledOGradRegWriteBlockDescriptor
<
Problem
>());
OGradDataType
*
dot_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
dq_dram_block_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
shuffled_do_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
lse_dram_block_window
=
make_tile_window
(
lse_dram_block_window_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
d_dram_block_window
=
make_tile_window
(
d_dram_block_window_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
auto
dot_lds_read_window
=
make_tile_window
(
dot_read_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeOGradTRegSliceBlockDescriptor
<
Problem
>());
// dS: Reg -> Reg -> LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
ds_lds_read_window
=
make_tile_window
(
ds_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK4
>
{}),
ds_lds_window
.
get_window_origin
(),
Policy
::
template
MakeSGradRegSliceBlockDescriptor
<
Problem
>());
auto
dst_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakeSGradTRegSliceBlockDescriptor
<
Problem
>());
// Bias: HBM ->Reg ->Reg ->LDS
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_block_window
=
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})});
// M/N
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})},
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_block_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
BiasDataType
*
bias_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
bias_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
bias_lds_ptr
,
Policy
::
template
MakeBiasLdsBlockDescriptor
<
Problem
>());
auto
bias_lds_write_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
bias_s_lds_read_window
=
make_tile_window
(
bias_lds_write_window
.
get_bottom_tensor_view
(),
bias_lds_write_window
.
get_window_lengths
(),
bias_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeBiasSTileDistribution
<
decltype
(
gemm_0
)>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// LSE: HBM -> LDS ->Reg
auto
lse_dram_window
=
make_tile_window
(
lse_dram_block_window
.
get_bottom_tensor_view
(),
lse_dram_block_window
.
get_window_lengths
(),
l
se
_dram_block_window
.
get_window_origin
()
,
lse_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window
_tmp
.
get_window_lengths
(),
{
se
qlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
LSEDataType
*
lse_lds_ptr
=
static_cast
<
LSEDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
lse_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
lse_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
lse_lds_write_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
lse_lds_read_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
},
Policy
::
template
MakeLSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
// D: HBM ->Reg
auto
d_dram_window
=
make_tile_window
(
d_dram_block_window
.
get_bottom_tensor_view
(),
d_dram_block_window
.
get_window_lengths
(),
d_dram_block_window
.
get_window_origin
()
,
d_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window
_tmp
.
get_window_lengths
(),
{
seqlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window
.
get_bottom_tensor_view
(),
bias_dram_block_window
.
get_window_lengths
(),
bias_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
DDataType
*
d_lds_ptr
=
static_cast
<
DDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()));
auto
d_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
d_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
d_lds_write_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
biast_lds_window
=
make_tile_window
(
biast_lds_shuffle_window
.
get_bottom_tensor_view
()
,
biast_lds_shuffle_window
.
get_window_lengths
(
),
biast_lds_shuffle_window
.
get_window_origin
()
,
Policy
::
template
Make
BiasTTileDistribution
<
decltype
(
gemm_0
)>());
auto
d_lds_read_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}
),
{
0
}
,
Policy
::
template
Make
LSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>
(
// RandVal: HBM ->Reg
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>(
randval_dram_block_window_tmp
,
seqlen_q_start
);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kM0
/
kK1
;
constexpr
index_t
k2_loops
=
kVHeaddim
/
kK2
;
constexpr
index_t
k3_loops
=
kM0
/
kK3
;
// BiasGrad
// Reg ->LDS ->Reg ->HBM
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
auto
dbias_lds_read_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
// ----------------------------Loop write out------------------------------//
auto
dq_dram_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
using
SPBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
==
kK0
,
"kQKHeaddim should equal to kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
==
kK2
,
"kVHeaddim should equal to kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
do
{
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window
.
get_bottom_tensor_view
(),
q_dram_block_window
.
get_window_lengths
(),
q_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
// Q DRAM tile window for
// load
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window
.
get_bottom_tensor_view
(),
do_dram_block_window
.
get_window_lengths
(),
do_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
// OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{}
;
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
)
;
__builtin_amdgcn_sched_barrier
(
0
);
// Hot loop
while
(
i_total_loops
<
num_total_loop
)
{
auto
q_block_tile
=
load_tile
(
q_dram_window
);
clear_tile
(
st_acc
);
// Initialize S^T
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
auto
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
if
constexpr
(
k0_loops
>
1
)
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
st_acc
,
get_slice_tile
(
q_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
block_sync_lds
();
});
}
store_tile
(
q_lds_window
,
q_block_tile
);
shuffle_tile
(
shuffled_q_block_tile
,
q_block_tile
);
store_tile
(
shuffled_q_lds_write_window
,
shuffled_q_block_tile
);
auto
do_block_tile
=
load_tile
(
do_dram_window
);
// prefetch load OGrad tile
{
// tail
block_sync_lds
();
gemm_0
(
st_acc
,
get_slice_tile
(
q_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kN0
,
k0_loops
*
kK0
>
{}));
block_sync_lds
();
}
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
block_sync_lds
();
auto
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
auto
lse
=
load_tile
(
lse_lds_read_window
);
block_sync_lds
();
// STAGE 1, Q@K Gemm0
auto
s_acc
=
SPBlockTileType
{};
s_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
block_sync_lds
(
);
auto
bias_
shuffle
_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
auto
shuffle
d_bias_tile
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_
shuffle
_tmp
,
bias_tile
);
store_tile
(
bias
t
_lds_
shuffl
e_window
,
bias_
shuffle
_tmp
);
shuffle_tile
(
shuffle
d_bias_tile
,
bias_tile
);
store_tile
(
bias_lds_
writ
e_window
,
shuffle
d_bias_tile
);
block_sync_lds
();
auto
bias
t
_tile
=
load_tile
(
bias
t
_lds_window
);
auto
bias
_s
_tile
=
load_tile
(
bias
_s
_lds_
read_
window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
=
raw_scale
*
x
+
type_convert
<
AccDataType
>
(
y
);
#else
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
#endif
},
s
t
_acc
,
bias
t
_tile
);
s_acc
,
bias
_s
_tile
);
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
constexpr
auto
st_spans
=
decltype
(
st_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
st_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
st_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s
t
_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc
(
i_j_idx
)
*=
raw_scale
;
#else
st_acc
(
i_j_idx
)
*=
scale
;
#endif
position_encoding
.
update
(
st_acc
(
i_j_idx
),
row
,
col
);
s_acc
(
i_j_idx
)
*=
scale
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
st_acc
);
#endif
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s
t
_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
set_tile_if
(
s_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
lse
=
load_tile
(
lse_dram_window
);
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
...
...
@@ -499,157 +585,162 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
}
};
auto
p
t
=
SP
T
BlockTileType
{};
constexpr
auto
p
t
_spans
=
decltype
(
p
t
)
::
get_distributed_spans
();
sweep_tile_span
(
p
t
_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
auto
p
=
SPBlockTileType
{};
constexpr
auto
p_spans
=
decltype
(
p
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
#endif
sweep_tile_span
(
pt_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p
t
(
i_j_idx
)
=
exp2
(
s
t
_acc
[
i_j_idx
]
-
row_lse
);
p
(
i_j_idx
)
=
exp2
(
s_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
p
t
(
i_j_idx
)
=
exp2
(
scale
*
s
t
_acc
[
i_j_idx
]
-
row_lse
);
p
(
i_j_idx
)
=
exp2
(
scale
*
s_acc
[
i_j_idx
]
-
row_lse
);
}
#else
pt
(
i_j_idx
)
=
exp
(
st_acc
[
i_j_idx
]
-
get_validated_lse
(
lse
[
i_idx
]));
#endif
});
});
if
constexpr
(
kHa
sDropout
)
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
dropout
.
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>
(
seqlen_q_st
art
+
i_total_loops
*
kM0
,
p
t
,
randval_dram_window
);
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_st
ep
,
k_origin
.
at
(
number
<
0
>
{})
,
p
,
randval_dram_window
);
}
// STAGE 3, P^T@OGrad^T Gemm1
block_sync_lds
();
store_tile
(
do_lds_window
,
do_block_tile
);
// store the prefetch
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
kHasDropout
)
const
auto
p_gemm
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
p
t
);
p
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
p
t
);
return
cast_tile
<
GemmDataType
>
(
p
);
}
}();
static_for
<
0
,
k1_loops
,
1
>
{}([
&
](
auto
i_k1
)
{
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
i_k1
*
kK1
,
0
>
{},
sequence
<
(
i_k1
+
1
)
*
kK1
,
kN0
>
{}),
get_slice_tile
(
dot_lds_window
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kVHeaddim
,
(
i_k1
+
1
)
*
kK1
>
{}));
block_sync_lds
();
});
// STAGE 3, P^T@OGrad^T Gemm1
auto
do_block_tile
=
load_tile
(
do_dram_window
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
auto
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
store_tile
(
do_lds_window
,
do_block_tile
);
shuffle_tile
(
shuffled_do_block_tile
,
do_block_tile
);
store_tile
(
shuffled_do_lds_write_window
,
shuffled_do_block_tile
);
store_tile
(
d_lds_write_window
,
d_block_tile
);
block_sync_lds
();
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
block_sync_lds
();
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
p_gemm
)>(
pt_reg_tensor
,
p_gemm
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
clear_tile
(
dpt_acc
);
// Initialize PGrad^T
auto
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
auto
d
=
load_tile
(
d_lds_read_window
);
block_sync_lds
();
static_for
<
0
,
k2_loops
,
1
>
{}([
&
](
auto
i_k2
)
{
block_sync_lds
();
gemm_2
(
dpt_acc
,
get_slice_tile
(
do_lds_window
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kM0
,
(
i_k2
+
1
)
*
kK2
>
{}),
get_slice_tile
(
v
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kN0
,
(
i_k2
+
1
)
*
kK2
>
{}));
block_sync_lds
();
});
auto
dp_acc
=
SPGradBlockTileType
{};
// STAGE 5, P^T(PGrad^T - D)
const
auto
d
=
load_tile
(
d_dram_window
);
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
auto
dst
=
SPGradTBlockTileType
{};
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
sweep_tile_span
(
dst_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
// STAGE 5, P^T(PGrad^T - D)
auto
ds
=
SPGradBlockTileType
{};
constexpr
auto
ds_spans
=
decltype
(
ds
)
::
get_distributed_spans
();
sweep_tile_span
(
ds_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
ds
t
_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
sweep_tile_span
(
ds_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
p
t
[
i_j_idx
]
>=
0
;
ds
t
(
i_j_idx
)
=
pt
[
i_j_idx
]
*
(
!
kHasDropout
||
undrop_flag
?
(
dpt_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
bool
undrop_flag
=
p
[
i_j_idx
]
>=
0
;
ds
(
i_j_idx
)
=
p
[
i_j_idx
]
*
(
!
FmhaDropout
::
IsDropout
||
undrop_flag
?
(
dp_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbias
t
=
[
&
]()
{
if
constexpr
(
kHa
sDropout
)
const
auto
dbias
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
ds
t
);
ds
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
ds
t
);
return
cast_tile
<
BiasGradDataType
>
(
ds
);
}
}();
store_tile
(
bias
t
_lds_
shuffl
e_window
,
dbias
t
);
store_tile
(
bias_lds_
writ
e_window
,
dbias
);
block_sync_lds
();
auto
dbias
t
_tile
=
load_tile
(
dbias
t
_lds_
shuffle
_window
);
auto
dbias
t_shuffle_tmp
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
auto
shuffled_
dbias_tile
=
load_tile
(
dbias_lds_
read
_window
);
auto
dbias
_tile
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
store_tile
(
dbias_dram_block_window
,
dbiast_shuffle_tmp
);
move_tile_window
(
dbias_dram_block_window
,
{
kM0
,
0
});
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
move_tile_window
(
dbias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
block_sync_lds
();
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
static_for
<
0
,
k3_loops
,
1
>
{}([
&
](
auto
i_k3
)
{
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
i_k3
*
kK3
,
0
>
{},
sequence
<
(
i_k3
+
1
)
*
kK3
,
kN0
>
{}),
get_slice_tile
(
qt_lds_window
,
sequence
<
0
,
i_k3
*
kK3
>
{},
sequence
<
kQKHeaddim
,
(
i_k3
+
1
)
*
kK3
>
{}));
block_sync_lds
();
});
const
auto
ds_gemm
=
cast_tile
<
GemmDataType
>
(
ds
);
// STAGE 7, SGrad@K^T Gemm4
store_tile
(
ds_lds_window
,
dst_gemm
);
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
ds_gemm
)>(
dst_reg_tensor
,
ds_gemm
);
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
// Initialize QGrad
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
store_tile
(
ds_lds_window
,
ds_gemm
);
block_sync_lds
();
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
// STAGE7 SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
gemm_4
(
dq_acc
,
get_slice_tile
(
ds_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kM0
,
(
i_k4
+
1
)
*
kK4
>
{}),
get_slice_tile
(
kt_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{}));
});
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor_next
=
load_tile
(
ds_lds_read_window
);
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
}
auto
kt_reg_tensor_slice
=
get_slice_tile
(
kt_reg_tensor
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{});
gemm_4
(
dq_acc
,
ds_reg_tensor
,
kt_reg_tensor_slice
);
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor
.
get_thread_buffer
()
=
ds_reg_tensor_next
.
get_thread_buffer
();
}
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
-
kN0
});
// QGrad Scale
if
constexpr
(
kHa
sDropout
)
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
...
...
@@ -658,34 +749,33 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
}
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
update_tile
(
dq_dram_block_window
,
dq
);
if
constexpr
(
kIsDeterministic
)
{
store_tile
(
dq_dram_window
,
dq_acc
);
}
else
{
update_tile
(
dq_dram_window
,
dq_acc
);
}
move_tile_window
(
dq_dram_window
,
{
kM0
,
0
});
// move tile windows
move_tile_window
(
q_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
dq_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
do_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
lse_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
}
while
(
++
i_total_loops
<
num_total_loop
);
i_total_loops
+=
1
;
seqlen_q_step
+=
kM0
;
}
//
KGrad
Scale
if
constexpr
(
kHa
sDropout
)
//
Results
Scale
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dk_acc
);
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
}
// VGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_k
s
_kt
s
_vr.hpp
→
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_k
r
_kt
r
_vr
_iglp
.hpp
View file @
f6ceef78
...
...
@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
dq_dk_dv_pipeline_ks_kts_vr
_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_
pipeline
_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwd
DQDKDV
Pipeline
KSKTSVR
DefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipelineK
S
KT
S
VR
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwdPipelineDefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipelineK
R
KT
R
VR
IGLP
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
...
...
@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
using
VGradDataType
=
remove_cvref_t
<
typename
Problem
::
VGradDataType
>
;
using
BiasGradDataType
=
remove_cvref_t
<
typename
Problem
::
BiasGradDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
HotLoopScheduler
=
typename
Policy
::
template
HotLoopScheduler
<
Problem
>;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
...
...
@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
bool
kQLoadOnce
=
false
;
static
constexpr
bool
kQTLoadOnce
=
false
;
static
constexpr
bool
kKLoadOnce
=
true
;
static
constexpr
bool
kKTLoadOnce
=
true
;
static
constexpr
bool
kVLoadOnce
=
true
;
static
constexpr
bool
kOGradLoadOnce
=
false
;
static
constexpr
bool
kOGradTLoadOnce
=
false
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kIsDeterministic
=
Problem
::
kIsDeterministic
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...
...
@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
kPadHeadDimQ
?
2
:
Policy
::
template
GetAlignmentQGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
1
;
static
constexpr
index_t
kAlignmentKGrad
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentKGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentVGrad
=
...
...
@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetTransposedAlignmentBias
<
Problem
>();
static
constexpr
const
char
*
name
=
"k
s
_kt
s
_vr"
;
static
constexpr
const
char
*
name
=
"k
r
_kt
r
_vr
_iglp
"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
...
...
@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
}
template
<
typename
QDramBlockWindowTmp
,
typename
QTDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KTDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
OGradDramBlockWindowTmp
,
typename
OGradTDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
DDramBlockWindowTmp
,
typename
QGradDramBlockWindowTmp
,
...
...
@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
const
QTDramBlockWindowTmp
&
qt_dram_block_window_tmp
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
const
KTDramBlockWindowTmp
&
kt_dram_block_window_tmp
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
const
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
const
OGradDramBlockWindowTmp
&
do_dram_block_window_tmp
,
const
OGradTDramBlockWindowTmp
&
dot_dram_block_window_tmp
,
const
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
const
DDramBlockWindowTmp
&
d_dram_block_window_tmp
,
const
QGradDramBlockWindowTmp
&
dq_dram_block_window_tmp
,
...
...
@@ -122,43 +107,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
,
#endif
float
rp_undrop
,
float
scale_rp_undrop
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
LSEDataType
,
remove_cvref_t
<
typename
LSEDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QGradDataType
,
remove_cvref_t
<
typename
QGradDramBlockWindowTmp
::
DataType
>>
,
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kQKHeaddim
==
QTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kQKHeaddim
==
KTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
OGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kVHeaddim
==
OGradTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
LSEDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
DDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
QGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
...
...
@@ -166,83 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// Q tile in LDS
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// QT tile in LDS
QDataType
*
qt_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
qt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsBlockDescriptor
<
Problem
>());
auto
qt_lds_window
=
make_tile_window
(
qt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kK3
>
{}),
{
0
,
0
});
// K tile in LDS
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// KT tile in LDS
KDataType
*
kt_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
kt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsBlockDescriptor
<
Problem
>());
auto
kt_lds_window
=
make_tile_window
(
kt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// OGrad tile in LDS
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
// OGradT tile in LDS
OGradDataType
*
dot_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
dot_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsBlockDescriptor
<
Problem
>());
auto
dot_lds_window
=
make_tile_window
(
dot_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kK1
>
{}),
{
0
,
0
});
// SGrad tile in LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType
*
biast_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()
+
Policy
::
template
GetSmemSizeKT
<
Problem
>()));
auto
biast_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
biast_lds_ptr
,
Policy
::
template
MakeBiasTLdsBlockDescriptor
<
Problem
>());
auto
biast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
dbiast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetPTOGradTBlockGemm
<
Problem
>();
...
...
@@ -250,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
constexpr
auto
gemm_3
=
Policy
::
template
GetSGradTQTBlockGemm
<
Problem
>();
constexpr
auto
gemm_4
=
Policy
::
template
GetSGradKTBlockGemm
<
Problem
>();
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVInRegDramTileDistribution
<
Problem
,
decltype
(
gemm_2
)>());
auto
v
=
load_tile
(
v_dram_window
);
// persistent V register tile
using
SPTBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradTBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
// init VGrad & KGrad
auto
dv_acc
=
decltype
(
gemm_1
.
MakeCBlockTile
()){};
auto
dk_acc
=
decltype
(
gemm_3
.
MakeCBlockTile
()){};
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
// K, HBM ->LDS ->Reg
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
k_origin
=
k_dram_window
.
get_window_origin
();
// Early termination
const
auto
[
seqlen_q_start
,
seqlen_q_end
]
=
mask
.
GetTileRangeAlongY
(
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
...
...
@@ -290,272 +169,444 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
}
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_block_tile
=
load_tile
(
k_dram_window
);
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKRegSliceBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
VDataType
*
v_lds_ptr
=
static_cast
<
VDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
store_tile
(
k_lds_window
,
k_block_tile
);
// // persistent K in LDS
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto
shuffled_k_block_tile
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeShuffledKRegWriteBlockDescriptor
<
Problem
>());
auto
kt_dram_block_window
=
kt_dram_block_window_tmp
;
KDataType
*
kt_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
shuffled_k_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
auto
kt_lds_read_window
=
make_tile_window
(
kt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeKTRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto
k_block_tile
=
load_tile
(
k_dram_window
);
auto
v_block_tile
=
load_tile
(
v_dram_window
);
auto
kt_dram_window
=
make_tile_window
(
kt_dram_block_window
.
get_bottom_tensor_view
(),
kt_dram_block_window
.
get_window_lengths
(),
kt_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKTDramTileDistribution
<
Problem
>());
// K^T DRAM tile window for
// load
store_tile
(
k_lds_write_window
,
k_block_tile
);
shuffle_tile
(
shuffled_k_block_tile
,
k_block_tile
);
store_tile
(
shuffled_k_lds_write_window
,
shuffled_k_block_tile
);
auto
kt_block_tile
=
load_tile
(
kt_dram_window
);
block_sync_lds
();
k_reg_tensor
=
load_tile
(
k_lds_read_window
);
block_sync_lds
();
auto
kt_shuffle_tmp
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeShuffledKTRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
kt_shuffle_tmp
,
kt_block_tile
);
auto
kt_reg_tensor
=
load_tile
(
kt_lds_read_window
);
store_tile
(
kt
_lds_window
,
kt_shuffle_tmp
);
// persistent K^T in LDS
store_tile
(
v
_lds_
write_
window
,
v_block_tile
);
auto
q_dram_block_window
=
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
q_lds_window
.
get_window_origin
(),
Policy
::
template
MakeQRegSliceBlockDescriptor
<
Problem
>());
auto
qt_dram_block_window
=
make_tile_window
(
qt_dram_block_window_tmp
.
get_bottom_tensor_view
(),
qt_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_q_start
});
auto
pt_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakePTRegSliceBlockDescriptor
<
Problem
>());
// QT: Reg -> Reg-> LDS
auto
shuffled_q_block_tile
=
make_static_distributed_tensor
<
QDataType
>
(
Policy
::
template
MakeShuffledQRegWriteBlockDescriptor
<
Problem
>());
auto
do_dram_block_window
=
QDataType
*
qt_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
shuffled_q_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
auto
qt_lds_read_window
=
make_tile_window
(
qt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeQTRegSliceBlockDescriptor
<
Problem
>());
// dO: HBM ->Reg ->LDS
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window_tmp
.
get_bottom_tensor_view
(),
do_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
},
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
auto
dot_dram_block_window
=
make_tile_window
(
dot_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dot_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_q_start
});
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()));
auto
dq_dram_block_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
lse_dram_block_window
=
make_tile_window
(
lse_dram_block_window_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
do_lds_window
.
get_window_origin
(),
Policy
::
template
MakeOGradRegSliceBlockDescriptor
<
Problem
>());
// dOT: Reg ->Reg ->LDS
auto
shuffled_do_block_tile
=
make_static_distributed_tensor
<
OGradDataType
>
(
Policy
::
template
MakeShuffledOGradRegWriteBlockDescriptor
<
Problem
>());
OGradDataType
*
dot_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()));
auto
shuffled_do_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
d_dram_block_window
=
make_tile_window
(
d_dram_block_window_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
auto
dot_lds_read_window
=
make_tile_window
(
dot_read_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kM0
>
{}),
{
0
,
0
},
Policy
::
template
MakeOGradTRegSliceBlockDescriptor
<
Problem
>());
// dS: Reg -> Reg -> LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
ds_lds_read_window
=
make_tile_window
(
ds_lds_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kM0
>
{},
number
<
kK4
>
{}),
ds_lds_window
.
get_window_origin
(),
Policy
::
template
MakeSGradRegSliceBlockDescriptor
<
Problem
>());
auto
dst_reg_tensor
=
make_static_distributed_tensor
<
GemmDataType
>
(
Policy
::
template
MakeSGradTRegSliceBlockDescriptor
<
Problem
>());
// Bias: HBM ->Reg ->Reg ->LDS
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_block_window
=
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})});
// M/N
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})},
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_block_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
BiasDataType
*
bias_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
bias_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
bias_lds_ptr
,
Policy
::
template
MakeBiasLdsBlockDescriptor
<
Problem
>());
auto
qt_dram_window
=
make_tile_window
(
qt_dram_block_window
.
get_bottom_tensor_view
(),
qt_dram_block_window
.
get_window_lengths
(),
qt_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQTDramTileDistribution
<
Problem
>());
auto
bias_lds_write_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
dot_dram
_window
=
make_tile_window
(
dot_dram_block
_window
.
get_bottom_tensor_view
(),
dot_dram_block
_window
.
get_window_lengths
(),
dot_dram_block
_window
.
get_window_origin
(),
Policy
::
template
Make
OGradTDram
TileDistribution
<
Problem
>());
auto
bias_s_lds_read
_window
=
make_tile_window
(
bias_lds_write
_window
.
get_bottom_tensor_view
(),
bias_lds_write
_window
.
get_window_lengths
(),
bias_lds_write
_window
.
get_window_origin
(),
Policy
::
template
Make
BiasS
TileDistribution
<
decltype
(
gemm_0
)
>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// LSE: HBM -> LDS ->Reg
auto
lse_dram_window
=
make_tile_window
(
lse_dram_block_window
.
get_bottom_tensor_view
(),
lse_dram_block_window
.
get_window_lengths
(),
l
se
_dram_block_window
.
get_window_origin
()
,
lse_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window
_tmp
.
get_window_lengths
(),
{
se
qlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
LSEDataType
*
lse_lds_ptr
=
static_cast
<
LSEDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
lse_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
lse_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
lse_lds_write_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
lse_lds_read_window
=
make_tile_window
(
lse_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
},
Policy
::
template
MakeLSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
// D: HBM ->Reg
auto
d_dram_window
=
make_tile_window
(
d_dram_block_window
.
get_bottom_tensor_view
(),
d_dram_block_window
.
get_window_lengths
(),
d_dram_block_window
.
get_window_origin
()
,
d_dram_block_window
_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window
_tmp
.
get_window_lengths
(),
{
seqlen_q_start
}
,
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window
.
get_bottom_tensor_view
(),
bias_dram_block_window
.
get_window_lengths
(),
bias_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
DDataType
*
d_lds_ptr
=
static_cast
<
DDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>())
)
;
auto
biast_lds_window
=
make_tile_window
(
biast_lds_shuffle_window
.
get_bottom_tensor_view
(),
biast_lds_shuffle_window
.
get_window_lengths
(),
biast_lds_shuffle_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTTileDistribution
<
decltype
(
gemm_0
)>());
auto
d_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
d_lds_ptr
,
Policy
::
template
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>
(
auto
d_lds_write_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
});
auto
d_lds_read_window
=
make_tile_window
(
d_lds
,
make_tuple
(
number
<
kM0
>
{}),
{
0
},
Policy
::
template
MakeLSEDLdsReadBlockDescriptor
<
Problem
,
decltype
(
gemm_0
)>());
// RandVal: HBM ->Reg
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>(
randval_dram_block_window_tmp
,
seqlen_q_start
);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kM0
/
kK1
;
constexpr
index_t
k2_loops
=
kVHeaddim
/
kK2
;
constexpr
index_t
k3_loops
=
kM0
/
kK3
;
// BiasGrad
// Reg ->LDS ->Reg ->HBM
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
auto
dbias_lds_read_window
=
make_tile_window
(
bias_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
// ----------------------------Loop write out------------------------------//
auto
dq_dram_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
using
SPBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
==
kK0
,
"kQKHeaddim should equal to kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
==
kK2
,
"kVHeaddim should equal to kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
do
{
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window
.
get_bottom_tensor_view
(),
q_dram_block_window
.
get_window_lengths
(),
q_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
// Q DRAM tile window for
// load
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window
.
get_bottom_tensor_view
(),
do_dram_block_window
.
get_window_lengths
(),
do_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
// OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{};
/*
* Prefetch Q, LSE, dO, D
*/
auto
q_block_tile
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
auto
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
auto
q_block_tile
=
load_tile
(
q_dram_window
);
{
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
auto
do_block_tile
=
load_tile
(
do_dram_window
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
clear_tile
(
st_acc
);
// Initialize S^T
auto
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write 0
q_block_tile
=
load_tile
(
q_dram_window
);
// global read 1
}
/*
* Store prefetched data into LDS
*/
block_sync_lds
();
store_tile
(
q_lds_window
,
q_block_tile
);
shuffle_tile
(
shuffled_q_block_tile
,
q_block_tile
);
store_tile
(
shuffled_q_lds_write_window
,
shuffled_q_block_tile
);
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
block_sync_lds
();
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write i + 1
q_block_tile
=
load_tile
(
q_dram_window
);
// global read i + 2
});
}
store_tile
(
do_lds_window
,
do_block_tile
);
shuffle_tile
(
shuffled_do_block_tile
,
do_block_tile
);
store_tile
(
shuffled_do_lds_write_window
,
shuffled_do_block_tile
);
const
auto
dot_prefetch
=
load_tile
(
dot_dram_window
);
// prefetch load OGrad^T tile
{
// tail
block_sync_lds
();
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
sequence
<
kN0
,
(
k0_loops
-
1
)
*
kK0
>
{}));
block_sync_lds
();
store_tile
(
d_lds_write_window
,
d_block_tile
);
block_sync_lds
();
store_tile
(
q_lds_window
,
q_block_tile
);
block_sync_lds
();
/*
* Prefetch LDS data into Reg to Asynchronous Data Movement and MFMA pipeline
*/
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kN0
,
k0_loops
*
kK0
>
{}));
}
auto
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
auto
lse
=
load_tile
(
lse_lds_read_window
);
auto
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
auto
d
=
load_tile
(
d_lds_read_window
);
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
__builtin_amdgcn_sched_barrier
(
0
);
// Hot loop
while
(
i_total_loops
<
(
num_total_loop
-
1
))
{
// STAGE 1, Q@K Gemm0
auto
s_acc
=
SPBlockTileType
{};
q_block_tile
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
do_block_tile
=
load_tile
(
do_dram_window
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
s_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
0
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
block_sync_lds
(
);
auto
bias_
shuffle
_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
auto
shuffle
d_bias_tile
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_
shuffle
_tmp
,
bias_tile
);
store_tile
(
bias
t
_lds_
shuffl
e_window
,
bias_
shuffle
_tmp
);
shuffle_tile
(
shuffle
d_bias_tile
,
bias_tile
);
store_tile
(
bias_lds_
writ
e_window
,
shuffle
d_bias_tile
);
block_sync_lds
();
auto
bias
t
_tile
=
load_tile
(
bias
t
_lds_window
);
auto
bias
_s
_tile
=
load_tile
(
bias
_s
_lds_
read_
window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
=
raw_scale
*
x
+
type_convert
<
AccDataType
>
(
y
);
#else
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
#endif
},
s
t
_acc
,
bias
t
_tile
);
s_acc
,
bias
_s
_tile
);
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
constexpr
auto
st_spans
=
decltype
(
st_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
st_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
st_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s
t
_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc
(
i_j_idx
)
*=
raw_scale
;
#else
st_acc
(
i_j_idx
)
*=
scale
;
#endif
position_encoding
.
update
(
st_acc
(
i_j_idx
),
row
,
col
);
s_acc
(
i_j_idx
)
*=
scale
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
st_acc
);
#endif
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s
t
_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
set_tile_if
(
s_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
lse
=
load_tile
(
lse_dram_window
);
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
...
...
@@ -570,278 +621,416 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
}
};
auto
p
t
=
SP
T
BlockTileType
{};
constexpr
auto
p
t
_spans
=
decltype
(
p
t
)
::
get_distributed_spans
();
sweep_tile_span
(
p
t
_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
auto
p
=
SPBlockTileType
{};
constexpr
auto
p_spans
=
decltype
(
p
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
#endif
sweep_tile_span
(
pt_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p
t
(
i_j_idx
)
=
exp2
(
s
t
_acc
[
i_j_idx
]
-
row_lse
);
p
(
i_j_idx
)
=
exp2
(
s_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
p
t
(
i_j_idx
)
=
exp2
(
scale
*
s
t
_acc
[
i_j_idx
]
-
row_lse
);
p
(
i_j_idx
)
=
exp2
(
scale
*
s_acc
[
i_j_idx
]
-
row_lse
);
}
#else
pt
(
i_j_idx
)
=
exp
(
st_acc
[
i_j_idx
]
-
get_validated_lse
(
lse
[
i_idx
]));
#endif
});
});
auto
dot_shuffle_tmp
=
make_static_distributed_tensor
<
OGradDataType
>
(
Policy
::
template
MakeShuffledOGradTRegBlockDescriptor
<
Problem
>());
block_sync_lds
();
{
shuffle_tile
(
dot_shuffle_tmp
,
dot_prefetch
);
store_tile
(
dot_lds_window
,
dot_shuffle_tmp
);
// store the prefetch
}
move_tile_window
(
dot_dram_window
,
{
0
,
kK1
});
if
constexpr
(
kHasDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
dropout
.
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>
(
seqlen_q_st
art
+
i_total_loops
*
kM0
,
p
t
,
randval_dram_window
);
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_st
ep
,
k_origin
.
at
(
number
<
0
>
{})
,
p
,
randval_dram_window
);
}
// STAGE 3, P^T@OGrad^T Gemm1
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
kHasDropout
)
const
auto
p_gemm
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
p
t
);
p
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
p
t
);
return
cast_tile
<
GemmDataType
>
(
p
);
}
}();
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
dot
=
load_tile
(
dot_dram_window
);
// load next OGrad^T
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
i_k1
*
kK1
,
0
>
{},
sequence
<
(
i_k1
+
1
)
*
kK1
,
kN0
>
{}),
dot_lds_window
);
block_sync_lds
();
shuffle_tile
(
dot_shuffle_tmp
,
dot
);
store_tile
(
dot_lds_window
,
dot_shuffle_tmp
);
// store the prefetch
move_tile_window
(
dot_dram_window
,
{
0
,
kK1
});
});
}
auto
do_block_tile
=
load_tile
(
do_dram_window
);
// prefetch load OGrad tile
// tail
{
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
(
k1_loops
-
1
)
*
kK1
,
0
>
{},
sequence
<
kM0
,
kN0
>
{}),
dot_lds_window
);
block_sync_lds
();
}
// STAGE 3, P^T@OGrad^T Gemm1
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
p_gemm
)>(
pt_reg_tensor
,
p_gemm
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
{
move_tile_window
(
do_dram_window
,
{
0
,
kK2
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 4, OGrad@V Gemm2
auto
dp_acc
=
SPGradBlockTileType
{};
clear_tile
(
dp
t
_acc
);
// Initialize PGrad^T
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
store_tile
(
do_lds_window
,
do_block_tile
);
// LDS write 0
do_block_tile
=
load_tile
(
do_dram_window
);
// global read 1
}
block_sync_lds
();
if
constexpr
(
k2_loops
>
2
)
{
static_for
<
0
,
k2_loops
-
2
,
1
>
{}([
&
](
auto
i_k2
)
{
block_sync_lds
();
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kN0
,
(
i_k2
+
1
)
*
kK2
>
{}));
block_sync_lds
();
move_tile_window
(
do_dram_window
,
{
0
,
kK2
});
store_tile
(
do_lds_window
,
do_block_tile
);
// LDS write i + 1
do_block_tile
=
load_tile
(
do_dram_window
);
// global read i + 2
});
}
store_tile
(
q_lds_window
,
q_block_tile
);
shuffle_tile
(
shuffled_q_block_tile
,
q_block_tile
);
store_tile
(
shuffled_q_lds_write_window
,
shuffled_q_block_tile
);
const
auto
qt_prefetch
=
load_tile
(
qt_dram_window
);
// prefetch load Q^T tile
{
// tail
block_sync_lds
();
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
(
k2_loops
-
2
)
*
kK2
>
{},
sequence
<
kN0
,
(
k2_loops
-
1
)
*
kK2
>
{}));
block_sync_lds
();
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
store_tile
(
do_lds_window
,
do_block_tile
);
block_sync_lds
();
store_tile
(
do_lds_window
,
do_block_tile
);
shuffle_tile
(
shuffled_do_block_tile
,
do_block_tile
);
store_tile
(
shuffled_do_lds_write_window
,
shuffled_do_block_tile
);
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
(
k2_loops
-
1
)
*
kK2
>
{},
sequence
<
kN0
,
k2_loops
*
kK2
>
{}));
}
store_tile
(
d_lds_write_window
,
d_block_tile
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 5, P^T(PGrad^T - D)
const
auto
d
=
load_tile
(
d_dram_window
);
auto
dst
=
SPGradTBlockTileType
{};
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
sweep_tile_span
(
dst_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
auto
ds
=
SPGradBlockTileType
{};
constexpr
auto
ds_spans
=
decltype
(
ds
)
::
get_distributed_spans
();
sweep_tile_span
(
ds_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
ds
t
_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
sweep_tile_span
(
ds_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
p
t
[
i_j_idx
]
>=
0
;
ds
t
(
i_j_idx
)
=
pt
[
i_j_idx
]
*
(
!
kHasDropout
||
undrop_flag
?
(
dpt_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
bool
undrop_flag
=
p
[
i_j_idx
]
>=
0
;
ds
(
i_j_idx
)
=
p
[
i_j_idx
]
*
(
!
FmhaDropout
::
IsDropout
||
undrop_flag
?
(
dp_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbias
t
=
[
&
]()
{
if
constexpr
(
kHa
sDropout
)
const
auto
dbias
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
ds
t
);
ds
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
ds
t
);
return
cast_tile
<
BiasGradDataType
>
(
ds
);
}
}();
store_tile
(
bias
t
_lds_
shuffl
e_window
,
dbias
t
);
store_tile
(
bias_lds_
writ
e_window
,
dbias
);
block_sync_lds
();
auto
dbias
t
_tile
=
load_tile
(
dbias
t
_lds_
shuffle
_window
);
auto
dbias
t_shuffle_tmp
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
auto
shuffled_
dbias_tile
=
load_tile
(
dbias_lds_
read
_window
);
auto
dbias
_tile
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
store_tile
(
dbias_dram_block_window
,
dbiast_shuffle_tmp
);
move_tile_window
(
dbias_dram_block_window
,
{
kM0
,
0
});
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
move_tile_window
(
dbias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
auto
qt_shuffle_tmp
=
make_static_distributed_tensor
<
QDataType
>
(
Policy
::
template
MakeShuffledQTRegBlockDescriptor
<
Problem
>());
const
auto
ds_gemm
=
cast_tile
<
GemmDataType
>
(
ds
);
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
ds_gemm
)>(
dst_reg_tensor
,
ds_gemm
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
store_tile
(
ds_lds_window
,
ds_gemm
);
block_sync_lds
();
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
lse
=
load_tile
(
lse_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE7 SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor_next
=
load_tile
(
ds_lds_read_window
);
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
}
auto
kt_reg_tensor_slice
=
get_slice_tile
(
kt_reg_tensor
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{});
gemm_4
(
dq_acc
,
ds_reg_tensor
,
kt_reg_tensor_slice
);
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor
.
get_thread_buffer
()
=
ds_reg_tensor_next
.
get_thread_buffer
();
}
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
-
kN0
});
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
d
=
load_tile
(
d_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
// QGrad Scale
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
}
if
constexpr
(
kIsDeterministic
)
{
shuffle_tile
(
qt_shuffle_tmp
,
qt_prefetch
);
store_tile
(
qt_lds_window
,
qt_shuffle_tmp
);
// store the prefetch
store_tile
(
dq_dram_window
,
dq_acc
);
}
move_tile_window
(
qt_dram_window
,
{
0
,
kK3
});
else
{
update_tile
(
dq_dram_window
,
dq_acc
);
}
move_tile_window
(
dq_dram_window
,
{
kM0
,
0
});
i_total_loops
+=
1
;
seqlen_q_step
+=
kM0
;
}
__builtin_amdgcn_sched_barrier
(
0
);
// Tail
auto
s_acc
=
SPBlockTileType
{};
// STAGE 1, Q@K Gemm0
s_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
auto
shuffled_bias_tile
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
shuffled_bias_tile
,
bias_tile
);
store_tile
(
bias_lds_write_window
,
shuffled_bias_tile
);
block_sync_lds
();
auto
bias_s_tile
=
load_tile
(
bias_s_lds_read_window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
},
s_acc
,
bias_s_tile
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
if
constexpr
(
k3_loops
>
1
)
s_acc
(
i_j_idx
)
*=
scale
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
static_for
<
0
,
k3_loops
-
1
,
1
>
{}([
&
](
auto
i_k3
)
{
const
auto
qt
=
load_tile
(
qt_dram_window
);
// load next Q^T
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
i_k3
*
kK3
,
0
>
{},
sequence
<
(
i_k3
+
1
)
*
kK3
,
kN0
>
{}),
qt_lds_window
);
block_sync_lds
();
shuffle_tile
(
qt_shuffle_tmp
,
qt
);
store_tile
(
qt_lds_window
,
qt_shuffle_tmp
);
// store the prefetch
move_tile_window
(
qt_dram_window
,
{
0
,
kK3
});
set_tile_if
(
s_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
// tail
}
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
(
k3_loops
-
1
)
*
kK3
,
0
>
{},
sequence
<
kM0
,
kN0
>
{}),
qt_lds_window
);
block_sync_lds
();
return
raw_lse
==
-
numeric
<
LSEDataType
>::
infinity
()
?
type_convert
<
LSEDataType
>
(
0.
f
)
:
raw_lse
;
}
else
{
return
raw_lse
;
}
};
// STAGE 7, SGrad@K^T Gemm4
store_tile
(
ds_lds_window
,
dst_gemm
);
auto
p
=
SPBlockTileType
{};
constexpr
auto
p_spans
=
decltype
(
p
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
// Initialize QGrad
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p
(
i_j_idx
)
=
exp2
(
s_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
p
(
i_j_idx
)
=
exp2
(
scale
*
s_acc
[
i_j_idx
]
-
row_lse
);
}
});
});
block_sync_lds
();
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
p
,
randval_dram_window
);
}
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
gemm_4
(
dq_acc
,
get_slice_tile
(
ds_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kM0
,
(
i_k4
+
1
)
*
kK4
>
{}),
get_slice_tile
(
kt_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{}));
// STAGE 3, P^T@OGrad^T Gemm1
const
auto
p_gemm
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
p
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
p
);
}
}();
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
p_gemm
)>(
pt_reg_tensor
,
p_gemm
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
// STAGE 4, OGrad@V Gemm2
auto
dp_acc
=
SPGradBlockTileType
{};
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
// STAGE 5, P^T(PGrad^T - D)
auto
ds
=
SPGradBlockTileType
{};
constexpr
auto
ds_spans
=
decltype
(
ds
)
::
get_distributed_spans
();
sweep_tile_span
(
ds_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
ds_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
p
[
i_j_idx
]
>=
0
;
ds
(
i_j_idx
)
=
p
[
i_j_idx
]
*
(
!
FmhaDropout
::
IsDropout
||
undrop_flag
?
(
dp_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
// QGrad Scale
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbias
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
ds
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
ds
);
}
}();
store_tile
(
bias_lds_write_window
,
dbias
);
block_sync_lds
();
auto
shuffled_dbias_tile
=
load_tile
(
dbias_lds_read_window
);
auto
dbias_tile
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
const
auto
ds_gemm
=
cast_tile
<
GemmDataType
>
(
ds
);
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
ds_gemm
)>(
dst_reg_tensor
,
ds_gemm
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
store_tile
(
ds_lds_window
,
ds_gemm
);
block_sync_lds
();
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
// STAGE 7, SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
ds_reg_tensor_next
=
load_tile
(
ds_lds_read_window
);
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
}
);
}
else
auto
kt_reg_tensor_slice
=
get_slice_tile
(
kt_reg_tensor
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{});
gemm_4
(
dq_acc
,
ds_reg_tensor
,
kt_reg_tensor_slice
);
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
ds_reg_tensor
.
get_thread_buffer
()
=
ds_reg_tensor_next
.
get_thread_buffer
(
);
}
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
update_tile
(
dq_dram_block_window
,
dq
);
});
// move tile windows
move_tile_window
(
q_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
dq_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
do_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
lse_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
}
while
(
++
i_total_loops
<
num_total_loop
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
//
KGrad
Scale
if
constexpr
(
kHa
sDropout
)
//
Results
Scale
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dk_acc
);
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
}
// VGrad Scale
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsDeterministic
)
{
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
store_tile
(
dq_dram_window
,
dq_acc
);
}
else
{
update_tile
(
dq_dram_window
,
dq_acc
);
}
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
deleted
100644 → 0
View file @
536c5458
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// This pipeline is v located in regs, k & k^t located in lds.
using
BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy
=
BlockFmhaBwdPipelineDefaultPolicy
<
/* QLoadOnce_ = */
false
,
/* QTLoadOnce_ = */
false
,
/* KLoadOnce_ = */
true
,
/* KTLoadOnce_ = */
true
,
/* VLoadOnce_ = */
true
,
/* OGradLoadOnce_ = */
false
,
/* OGradTLoadOnce_ = */
false
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
deleted
100644 → 0
View file @
536c5458
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy
>
struct
BlockFmhaBwdDQDKDVPipelineKSVR
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
GemmDataType
=
remove_cvref_t
<
typename
Problem
::
GemmDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
DDataType
=
remove_cvref_t
<
typename
Problem
::
DDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
using
QGradDataType
=
remove_cvref_t
<
typename
Problem
::
QGradDataType
>
;
using
KGradDataType
=
remove_cvref_t
<
typename
Problem
::
KGradDataType
>
;
using
VGradDataType
=
remove_cvref_t
<
typename
Problem
::
VGradDataType
>
;
using
BiasGradDataType
=
remove_cvref_t
<
typename
Problem
::
BiasGradDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
static
constexpr
index_t
kBlockPerCu
=
Problem
::
kBlockPerCu
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK2
=
BlockFmhaShape
::
kK2
;
static
constexpr
index_t
kK3
=
BlockFmhaShape
::
kK3
;
static
constexpr
index_t
kK4
=
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
bool
kQLoadOnce
=
false
;
static
constexpr
bool
kQTLoadOnce
=
false
;
static
constexpr
bool
kKLoadOnce
=
true
;
static
constexpr
bool
kKTLoadOnce
=
false
;
static
constexpr
bool
kVLoadOnce
=
true
;
static
constexpr
bool
kOGradLoadOnce
=
false
;
static
constexpr
bool
kOGradTLoadOnce
=
false
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Problem
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentQGrad
=
kPadHeadDimQ
?
2
:
Policy
::
template
GetAlignmentQGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentKGrad
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentKGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentVGrad
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentVGrad
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetTransposedAlignmentBias
<
Problem
>();
static
constexpr
const
char
*
name
=
"ks_vr"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
QTDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KTDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
OGradDramBlockWindowTmp
,
typename
OGradTDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
DDramBlockWindowTmp
,
typename
QGradDramBlockWindowTmp
,
typename
BiasGradDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
const
QTDramBlockWindowTmp
&
qt_dram_block_window_tmp
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
const
KTDramBlockWindowTmp
&
/*kt_dram_block_window_tmp*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
const
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
const
OGradDramBlockWindowTmp
&
do_dram_block_window_tmp
,
const
OGradTDramBlockWindowTmp
&
dot_dram_block_window_tmp
,
const
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
const
DDramBlockWindowTmp
&
d_dram_block_window_tmp
,
const
QGradDramBlockWindowTmp
&
dq_dram_block_window_tmp
,
const
BiasGradDramBlockWindowTmp
&
dbias_dram_block_window_tmp
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
,
#endif
float
rp_undrop
,
float
scale_rp_undrop
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
OGradDataType
,
remove_cvref_t
<
typename
OGradTDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
LSEDataType
,
remove_cvref_t
<
typename
LSEDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
DDataType
,
remove_cvref_t
<
typename
DDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QGradDataType
,
remove_cvref_t
<
typename
QGradDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kQKHeaddim
==
QTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
OGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kVHeaddim
==
OGradTDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
LSEDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
DDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
QGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// Q tile in LDS
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// QT tile in LDS
QDataType
*
qt_lds_ptr
=
static_cast
<
QDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
qt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsBlockDescriptor
<
Problem
>());
auto
qt_lds_window
=
make_tile_window
(
qt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kK3
>
{}),
{
0
,
0
});
// K tile in LDS
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
0
});
// KT tile in LDS
auto
kt_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsBlockDescriptorAsKT
<
Problem
>());
auto
kt_lds_window
=
make_tile_window
(
kt_lds
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// OGrad tile in LDS
OGradDataType
*
do_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
do_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
// OGradT tile in LDS
OGradDataType
*
dot_lds_ptr
=
static_cast
<
OGradDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
dot_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsBlockDescriptor
<
Problem
>());
auto
dot_lds_window
=
make_tile_window
(
dot_lds
,
make_tuple
(
number
<
kVHeaddim
>
{},
number
<
kK1
>
{}),
{
0
,
0
});
// SGrad tile in LDS
GemmDataType
*
ds_lds_ptr
=
static_cast
<
GemmDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
ds_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
ds_lds_ptr
,
Policy
::
template
MakeSGradLdsBlockDescriptor
<
Problem
>());
auto
ds_lds_window
=
make_tile_window
(
ds_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType
*
biast_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
biast_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
biast_lds_ptr
,
Policy
::
template
MakeBiasTLdsBlockDescriptor
<
Problem
>());
auto
biast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
});
auto
dbiast_lds_shuffle_window
=
make_tile_window
(
biast_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
static_assert
(
std
::
is_same_v
<
BiasDataType
,
BiasGradDataType
>
,
"BiasDataType and BiasGradDataType should be the same!"
);
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetPTOGradTBlockGemm
<
Problem
>();
constexpr
auto
gemm_2
=
Policy
::
template
GetOGradVBlockGemm
<
Problem
>();
constexpr
auto
gemm_3
=
Policy
::
template
GetSGradTQTBlockGemm
<
Problem
>();
constexpr
auto
gemm_4
=
Policy
::
template
GetSGradKTBlockGemm
<
Problem
>();
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVInRegDramTileDistribution
<
Problem
,
decltype
(
gemm_2
)>());
auto
v
=
load_tile
(
v_dram_window
);
// persistent V register tile
using
SPTBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradTBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
// init VGrad & KGrad
auto
dv_acc
=
decltype
(
gemm_1
.
MakeCBlockTile
()){};
auto
dk_acc
=
decltype
(
gemm_3
.
MakeCBlockTile
()){};
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
k_origin
=
k_dram_window
.
get_window_origin
();
const
auto
[
seqlen_q_start
,
seqlen_q_end
]
=
mask
.
GetTileRangeAlongY
(
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_q_end
-
seqlen_q_start
,
kM0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
)
{
if
(
num_total_loop
<=
0
)
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
}
}
auto
k_block_tile
=
load_tile
(
k_dram_window
);
store_tile
(
k_lds_window
,
k_block_tile
);
// // persistent K in LDS
auto
q_dram_block_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
qt_dram_block_window
=
make_tile_window
(
qt_dram_block_window_tmp
.
get_bottom_tensor_view
(),
qt_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_q_start
});
auto
do_dram_block_window
=
make_tile_window
(
do_dram_block_window_tmp
.
get_bottom_tensor_view
(),
do_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
dot_dram_block_window
=
make_tile_window
(
dot_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dot_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_q_start
});
auto
dq_dram_block_window
=
make_tile_window
(
dq_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
auto
lse_dram_block_window
=
make_tile_window
(
lse_dram_block_window_tmp
.
get_bottom_tensor_view
(),
lse_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
auto
d_dram_block_window
=
make_tile_window
(
d_dram_block_window_tmp
.
get_bottom_tensor_view
(),
d_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_block_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
bias_origin
.
at
(
number
<
1
>
{})});
// M/N
const
auto
dbias_origin
=
dbias_dram_block_window_tmp
.
get_window_origin
();
auto
dbias_dram_block_window
=
make_tile_window
(
dbias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dbias_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
dbias_origin
.
at
(
number
<
1
>
{})});
// M/N
auto
qt_dram_window
=
make_tile_window
(
qt_dram_block_window
.
get_bottom_tensor_view
(),
qt_dram_block_window
.
get_window_lengths
(),
qt_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQTDramTileDistribution
<
Problem
>());
auto
dot_dram_window
=
make_tile_window
(
dot_dram_block_window
.
get_bottom_tensor_view
(),
dot_dram_block_window
.
get_window_lengths
(),
dot_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeOGradTDramTileDistribution
<
Problem
>());
auto
lse_dram_window
=
make_tile_window
(
lse_dram_block_window
.
get_bottom_tensor_view
(),
lse_dram_block_window
.
get_window_lengths
(),
lse_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
d_dram_window
=
make_tile_window
(
d_dram_block_window
.
get_bottom_tensor_view
(),
d_dram_block_window
.
get_window_lengths
(),
d_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeLSEDDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window
.
get_bottom_tensor_view
(),
bias_dram_block_window
.
get_window_lengths
(),
bias_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
auto
biast_lds_window
=
make_tile_window
(
biast_lds_shuffle_window
.
get_bottom_tensor_view
(),
biast_lds_shuffle_window
.
get_window_lengths
(),
biast_lds_shuffle_window
.
get_window_origin
(),
Policy
::
template
MakeBiasTTileDistribution
<
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
),
false
>
(
randval_dram_block_window_tmp
,
seqlen_q_start
);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kM0
/
kK1
;
constexpr
index_t
k2_loops
=
kVHeaddim
/
kK2
;
constexpr
index_t
k3_loops
=
kM0
/
kK3
;
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
do
{
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window
.
get_bottom_tensor_view
(),
q_dram_block_window
.
get_window_lengths
(),
q_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
// Q DRAM tile window for
// load
auto
do_dram_window
=
make_tile_window
(
do_dram_block_window
.
get_bottom_tensor_view
(),
do_dram_block_window
.
get_window_lengths
(),
do_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeOGradDramTileDistribution
<
Problem
>());
// OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{};
auto
q_block_tile
=
load_tile
(
q_dram_window
);
{
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
clear_tile
(
st_acc
);
// Initialize S^T
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write 0
q_block_tile
=
load_tile
(
q_dram_window
);
// global read 1
}
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
block_sync_lds
();
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
store_tile
(
q_lds_window
,
q_block_tile
);
// LDS write i + 1
q_block_tile
=
load_tile
(
q_dram_window
);
// global read i + 2
});
}
const
auto
dot_prefetch
=
load_tile
(
dot_dram_window
);
// prefetch load OGrad^T tile
{
// tail
block_sync_lds
();
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
sequence
<
kN0
,
(
k0_loops
-
1
)
*
kK0
>
{}));
block_sync_lds
();
store_tile
(
q_lds_window
,
q_block_tile
);
block_sync_lds
();
gemm_0
(
st_acc
,
q_lds_window
,
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kN0
,
k0_loops
*
kK0
>
{}));
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
block_sync_lds
();
auto
bias_shuffle_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_shuffle_tmp
,
bias_tile
);
store_tile
(
biast_lds_shuffle_window
,
bias_shuffle_tmp
);
block_sync_lds
();
auto
biast_tile
=
load_tile
(
biast_lds_window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
=
raw_scale
*
x
+
type_convert
<
AccDataType
>
(
y
);
#else
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
#endif
},
st_acc
,
biast_tile
);
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
constexpr
auto
st_spans
=
decltype
(
st_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
st_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
st_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
st_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc
(
i_j_idx
)
*=
raw_scale
;
#else
st_acc
(
i_j_idx
)
*=
scale
;
#endif
position_encoding
.
update
(
st_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
st_acc
);
#endif
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
q_origin
=
q_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
st_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
lse
=
load_tile
(
lse_dram_window
);
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
raw_lse
==
-
numeric
<
LSEDataType
>::
infinity
()
?
type_convert
<
LSEDataType
>
(
0.
f
)
:
raw_lse
;
}
else
{
return
raw_lse
;
}
};
auto
pt
=
SPTBlockTileType
{};
constexpr
auto
pt_spans
=
decltype
(
pt
)
::
get_distributed_spans
();
sweep_tile_span
(
pt_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
#endif
sweep_tile_span
(
pt_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
pt
(
i_j_idx
)
=
exp2
(
st_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
pt
(
i_j_idx
)
=
exp2
(
scale
*
st_acc
[
i_j_idx
]
-
row_lse
);
}
#else
pt
(
i_j_idx
)
=
exp
(
st_acc
[
i_j_idx
]
-
get_validated_lse
(
lse
[
i_idx
]));
#endif
});
});
auto
dot_shuffle_tmp
=
make_static_distributed_tensor
<
OGradDataType
>
(
Policy
::
template
MakeShuffledOGradTRegBlockDescriptor
<
Problem
>());
block_sync_lds
();
{
shuffle_tile
(
dot_shuffle_tmp
,
dot_prefetch
);
store_tile
(
dot_lds_window
,
dot_shuffle_tmp
);
// store the prefetch
}
move_tile_window
(
dot_dram_window
,
{
0
,
kK1
});
if
constexpr
(
kHasDropout
)
{
dropout
.
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>
(
seqlen_q_start
+
i_total_loops
*
kM0
,
pt
,
randval_dram_window
);
}
// STAGE 3, P^T@OGrad^T Gemm1
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
kHasDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
pt
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
pt
);
}
}();
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
dot
=
load_tile
(
dot_dram_window
);
// load next OGrad^T
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
i_k1
*
kK1
,
0
>
{},
sequence
<
(
i_k1
+
1
)
*
kK1
,
kN0
>
{}),
dot_lds_window
);
block_sync_lds
();
shuffle_tile
(
dot_shuffle_tmp
,
dot
);
store_tile
(
dot_lds_window
,
dot_shuffle_tmp
);
// store the prefetch
move_tile_window
(
dot_dram_window
,
{
0
,
kK1
});
});
}
auto
do_block_tile
=
load_tile
(
do_dram_window
);
// prefetch load OGrad tile
// tail
{
block_sync_lds
();
gemm_1
(
dv_acc
,
get_slice_tile
(
pt_gemm
,
sequence
<
(
k1_loops
-
1
)
*
kK1
,
0
>
{},
sequence
<
kM0
,
kN0
>
{}),
dot_lds_window
);
block_sync_lds
();
}
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
{
move_tile_window
(
do_dram_window
,
{
0
,
kK2
});
clear_tile
(
dpt_acc
);
// Initialize PGrad^T
store_tile
(
do_lds_window
,
do_block_tile
);
// LDS write 0
do_block_tile
=
load_tile
(
do_dram_window
);
// global read 1
}
if
constexpr
(
k2_loops
>
2
)
{
static_for
<
0
,
k2_loops
-
2
,
1
>
{}([
&
](
auto
i_k2
)
{
block_sync_lds
();
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
i_k2
*
kK2
>
{},
sequence
<
kN0
,
(
i_k2
+
1
)
*
kK2
>
{}));
block_sync_lds
();
move_tile_window
(
do_dram_window
,
{
0
,
kK2
});
store_tile
(
do_lds_window
,
do_block_tile
);
// LDS write i + 1
do_block_tile
=
load_tile
(
do_dram_window
);
// global read i + 2
});
}
const
auto
qt_prefetch
=
load_tile
(
qt_dram_window
);
// prefetch load Q^T tile
{
// tail
block_sync_lds
();
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
(
k2_loops
-
2
)
*
kK2
>
{},
sequence
<
kN0
,
(
k2_loops
-
1
)
*
kK2
>
{}));
block_sync_lds
();
store_tile
(
do_lds_window
,
do_block_tile
);
block_sync_lds
();
gemm_2
(
dpt_acc
,
do_lds_window
,
get_slice_tile
(
v
,
sequence
<
0
,
(
k2_loops
-
1
)
*
kK2
>
{},
sequence
<
kN0
,
k2_loops
*
kK2
>
{}));
}
// STAGE 5, P^T(PGrad^T - D)
const
auto
d
=
load_tile
(
d_dram_window
);
auto
dst
=
SPGradTBlockTileType
{};
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
sweep_tile_span
(
dst_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
dst_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
pt
[
i_j_idx
]
>=
0
;
dst
(
i_j_idx
)
=
pt
[
i_j_idx
]
*
(
!
kHasDropout
||
undrop_flag
?
(
dpt_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbiast
=
[
&
]()
{
if
constexpr
(
kHasDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
dst
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
dst
);
}
}();
store_tile
(
biast_lds_shuffle_window
,
dbiast
);
block_sync_lds
();
auto
dbiast_tile
=
load_tile
(
dbiast_lds_shuffle_window
);
auto
dbiast_shuffle_tmp
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
store_tile
(
dbias_dram_block_window
,
dbiast_shuffle_tmp
);
move_tile_window
(
dbias_dram_block_window
,
{
kM0
,
0
});
}
// STAGE 6, SGrad^T@Q^T Gemm3
auto
qt_shuffle_tmp
=
make_static_distributed_tensor
<
QDataType
>
(
Policy
::
template
MakeShuffledQTRegBlockDescriptor
<
Problem
>());
block_sync_lds
();
{
shuffle_tile
(
qt_shuffle_tmp
,
qt_prefetch
);
store_tile
(
qt_lds_window
,
qt_shuffle_tmp
);
// store the prefetch
}
move_tile_window
(
qt_dram_window
,
{
0
,
kK3
});
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
if
constexpr
(
k3_loops
>
1
)
{
static_for
<
0
,
k3_loops
-
1
,
1
>
{}([
&
](
auto
i_k3
)
{
const
auto
qt
=
load_tile
(
qt_dram_window
);
// load next Q^T
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
i_k3
*
kK3
,
0
>
{},
sequence
<
(
i_k3
+
1
)
*
kK3
,
kN0
>
{}),
qt_lds_window
);
block_sync_lds
();
shuffle_tile
(
qt_shuffle_tmp
,
qt
);
store_tile
(
qt_lds_window
,
qt_shuffle_tmp
);
// store the prefetch
move_tile_window
(
qt_dram_window
,
{
0
,
kK3
});
});
}
// tail
{
block_sync_lds
();
gemm_3
(
dk_acc
,
get_slice_tile
(
dst_gemm
,
sequence
<
(
k3_loops
-
1
)
*
kK3
,
0
>
{},
sequence
<
kM0
,
kN0
>
{}),
qt_lds_window
);
block_sync_lds
();
}
// STAGE 7, SGrad@K^T Gemm4
store_tile
(
ds_lds_window
,
dst_gemm
);
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
// Initialize QGrad
block_sync_lds
();
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
gemm_4
(
dq_acc
,
get_slice_tile
(
ds_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kM0
,
(
i_k4
+
1
)
*
kK4
>
{}),
get_slice_tile
(
kt_lds_window
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{}));
});
// QGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
}
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
update_tile
(
dq_dram_block_window
,
dq
);
// move tile windows
move_tile_window
(
q_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
dq_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
do_dram_block_window
,
{
kM0
,
0
});
move_tile_window
(
lse_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
}
while
(
++
i_total_loops
<
num_total_loop
);
// KGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dk_acc
);
}
else
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
}
// VGrad Scale
if
constexpr
(
kHasDropout
)
{
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
deleted
100644 → 0
View file @
536c5458
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// Policy preset for the KS/VR bwd pipeline: V is held in registers (VLoadOnce),
// K is staged once into LDS (KLoadOnce); Q, Q^T, K^T, dO and dO^T are streamed
// per K-slice instead of being resident for the whole tile.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
deleted
100644 → 0
View file @
536c5458
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// Policy preset for the QS/KS/VR/dOS bwd pipeline: V is held in registers
// (VLoadOnce); Q, K and dO are each staged once into LDS for the whole tile;
// the transposed operands (Q^T, K^T, dO^T) are still streamed per K-slice.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ true,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ true,
                                      /* OGradTLoadOnce_ = */ false>;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
f6ceef78
...
...
@@ -11,6 +11,8 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -18,60 +20,215 @@
namespace
ck_tile
{
template
<
bool
QLoadOnce_
,
bool
QTLoadOnce_
,
bool
KLoadOnce_
,
bool
KTLoadOnce_
,
bool
VLoadOnce_
,
bool
OGradLoadOnce_
,
bool
OGradTLoadOnce_
>
struct
BlockFmhaBwdPipelineDefaultPolicy
{
// Compile-time knobs mirroring the template parameters. Each flag selects
// whether the corresponding operand is loaded once for the full block length
// (resident in LDS, or in VGPRs for V) versus streamed slice-by-slice inside
// the inner gemm loops.
static constexpr bool QLoadOnce = QLoadOnce_; // if q load whole block length (qkhdim) to LDS at once
static constexpr bool QTLoadOnce = QTLoadOnce_; // if q^t load whole block length (qkhdim) to LDS at once
static constexpr bool KLoadOnce = KLoadOnce_; // if k load whole block length (qkhdim) to LDS at once
static constexpr bool KTLoadOnce = KTLoadOnce_; // if k^t load whole block length (qkhdim) to LDS at once
static constexpr bool VLoadOnce = VLoadOnce_; // if v load whole block length (vhdim) to Vgprs at once
static constexpr bool OGradLoadOnce = OGradLoadOnce_; // if do load whole block length (vhdim) to LDS at once
static constexpr bool OGradTLoadOnce = OGradTLoadOnce_; // if do^t load whole block length (vhdim) to LDS at once
// Builds the block gemm used for S^T = Q@K (gemm 0 of the bwd pass).
// Problem shape is kM0 x kN0 x kK0; the warp-level MFMA is dispatched from
// Gemm0WarpTile and both A (Q) and B (K) are consumed from registers
// (BlockGemmARegBRegCRegV1).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
    using BlockGemmProblem =
        BlockGemmPipelineProblem<typename Problem::QDataType,
                                 typename Problem::KDataType,
                                 typename Problem::AccDataType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Problem::BlockFmhaShape::kM0,
                                               Problem::BlockFmhaShape::kN0,
                                               Problem::BlockFmhaShape::kK0>>;
    // Warp tile sizes come from Gemm0WarpTile (M, N, K at indices 0, 1, 2).
    // The last argument toggles based on the warp-tile M size; NOTE(review):
    // its exact meaning (swizzle/layout variant) is defined by
    // WarpGemmMfmaDispatcher — confirm against that dispatcher's parameters.
    using WarpGemm =
        WarpGemmMfmaDispatcher<typename Problem::QDataType,
                               typename Problem::KDataType,
                               typename Problem::AccDataType,
                               Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
                               Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
                               Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
                               false,
                               Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16
                                   ? false
                                   : true>;
    using BlockGemmPolicy =
        BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
                                            typename Problem::KDataType,
                                            typename Problem::AccDataType,
                                            typename Problem::BlockFmhaShape::Gemm0BlockWarps,
                                            WarpGemm>;
    return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// Builds the block gemm used for dV = P^T @ dO^T (gemm 1 of the bwd pass).
// Problem shape is kN0 x kVHeaddim x kK1; warp MFMA dispatched from
// Gemm1WarpTile with the trailing flag fixed to true (C-layout variant per
// WarpGemmMfmaDispatcher). A (P^T, GemmDataType) and B (dO) are register
// resident.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
    using BlockGemmProblem =
        BlockGemmPipelineProblem<typename Problem::GemmDataType,
                                 typename Problem::OGradDataType,
                                 typename Problem::AccDataType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Problem::BlockFmhaShape::kN0,
                                               Problem::BlockFmhaShape::kVHeaddim,
                                               Problem::BlockFmhaShape::kK1>>;
    using WarpGemm =
        WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
                               typename Problem::OGradDataType,
                               typename Problem::AccDataType,
                               Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
                               Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
                               Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
                               true>;
    using BlockGemmPolicy =
        BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
                                            typename Problem::OGradDataType,
                                            typename Problem::AccDataType,
                                            typename Problem::BlockFmhaShape::Gemm1BlockWarps,
                                            WarpGemm>;
    return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// Builds the block gemm used for dP^T = dO @ V (gemm 2 of the bwd pass).
// Problem shape is kM0 x kN0 x kK2; warp MFMA dispatched from Gemm2WarpTile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
    using BlockGemmProblem =
        BlockGemmPipelineProblem<typename Problem::OGradDataType,
                                 typename Problem::VDataType,
                                 typename Problem::AccDataType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Problem::BlockFmhaShape::kM0,
                                               Problem::BlockFmhaShape::kN0,
                                               Problem::BlockFmhaShape::kK2>>;
    // NOTE(review): the M/N/K sizes below use Gemm2WarpTile, but the trailing
    // flag is derived from Gemm0WarpTile — looks like a copy-paste from
    // GetQKBlockGemm. Harmless if both warp tiles always share the same M
    // size, but worth confirming; otherwise this should read Gemm2WarpTile.
    using WarpGemm =
        WarpGemmMfmaDispatcher<typename Problem::OGradDataType,
                               typename Problem::VDataType,
                               typename Problem::AccDataType,
                               Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
                               Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
                               Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
                               false,
                               Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16
                                   ? false
                                   : true>;
    using BlockGemmPolicy =
        BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
                                            typename Problem::VDataType,
                                            typename Problem::AccDataType,
                                            typename Problem::BlockFmhaShape::Gemm2BlockWarps,
                                            WarpGemm>;
    return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// Builds the block gemm used for dK = dS^T @ Q^T (gemm 3 of the bwd pass).
// Problem shape is kN0 x kQKHeaddim x kK3; warp MFMA dispatched from
// Gemm3WarpTile with the trailing flag fixed to true.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
    using BlockGemmProblem =
        BlockGemmPipelineProblem<typename Problem::GemmDataType,
                                 typename Problem::QDataType,
                                 typename Problem::AccDataType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Problem::BlockFmhaShape::kN0,
                                               Problem::BlockFmhaShape::kQKHeaddim,
                                               Problem::BlockFmhaShape::kK3>>;
    using WarpGemm =
        WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
                               typename Problem::QDataType,
                               typename Problem::AccDataType,
                               Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
                               Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
                               Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
                               true>;
    using BlockGemmPolicy =
        BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
                                            typename Problem::QDataType,
                                            typename Problem::AccDataType,
                                            typename Problem::BlockFmhaShape::Gemm3BlockWarps,
                                            WarpGemm>;
    return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// Builds the block gemm used for dQ = dS @ K^T (gemm 4 of the bwd pass).
// Problem shape is kM0 x kQKHeaddim x kK4; warp MFMA dispatched from
// Gemm4WarpTile with the trailing flag fixed to false.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
    using BlockGemmProblem =
        BlockGemmPipelineProblem<typename Problem::GemmDataType,
                                 typename Problem::KDataType,
                                 typename Problem::AccDataType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Problem::BlockFmhaShape::kM0,
                                               Problem::BlockFmhaShape::kQKHeaddim,
                                               Problem::BlockFmhaShape::kK4>>;
    using WarpGemm =
        WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
                               typename Problem::KDataType,
                               typename Problem::AccDataType,
                               Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
                               Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
                               Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
                               false>;
    using BlockGemmPolicy =
        BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
                                            typename Problem::KDataType,
                                            typename Problem::AccDataType,
                                            typename Problem::BlockFmhaShape::Gemm4BlockWarps,
                                            WarpGemm>;
    return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// these are for global load
// Vector (element) width for global loads of Q.
// Defect fixed: the scraped diff concatenated the removed one-liner
// (`return 16 / sizeof(QDataType);`) with the added implementation, leaving an
// unreachable early return and a duplicate `using QDataType` redeclaration
// (invalid C++). This keeps the added implementation only.
//
// The load width is capped at 16 bytes (kMaxVecLoad) and reduced when each
// thread owns too few elements (total_pixels) of the kM0 x kK0 tile to
// sustain it; kMinVecLoad corresponds to a 4-byte minimum.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
    using QDataType = remove_cvref_t<typename Problem::QDataType>;

    constexpr index_t kBlockSize  = Problem::kBlockSize;
    constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
    constexpr index_t kKPerBlock  = Problem::BlockFmhaShape::kK0;

    constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
    constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);

    // elements of the tile owned by each thread
    constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;

    constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                     ? kMaxVecLoad
                                     : (total_pixels / kMinVecLoad);
    return kVecLoad;
}
// Vector (element) width for global loads of K.
// Defect fixed: removed-diff one-liner (`return 16 / sizeof(KDataType);`) was
// concatenated with the added implementation (unreachable return + duplicate
// `using KDataType`). Keeps the added implementation only.
//
// Same scheme as GetAlignmentQ, over the kN0 x kK0 K tile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
    using KDataType = remove_cvref_t<typename Problem::KDataType>;

    constexpr index_t kBlockSize  = Problem::kBlockSize;
    constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
    constexpr index_t kKPerBlock  = Problem::BlockFmhaShape::kK0;

    constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType); // 16-byte cap
    constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);  // 4-byte floor

    // elements of the tile owned by each thread
    constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;

    constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                     ? kMaxVecLoad
                                     : (total_pixels / kMinVecLoad);
    return kVecLoad;
}
// Vector (element) width for global loads of V.
// Defect fixed: the scraped diff concatenated the removed else-branch
// (`return 16 / sizeof(VDataType);`) with the added else-branch body after the
// closing brace (code after both branches returned — invalid/dead). This
// keeps the added implementation.
//
// When V is register-resident (VLoadOnce) the load width must match the
// OGrad@V warp gemm's B-operand layout; otherwise it is capped at 16 bytes
// and limited by the per-thread share of the kN0 x kK2 tile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
    if constexpr(VLoadOnce)
    {
        using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG = remove_cvref_t<decltype(config.template at<0>())>;
        // elements per lane along K of the warp gemm's B operand
        return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
    }
    else
    {
        using VDataType = remove_cvref_t<typename Problem::VDataType>;

        constexpr index_t kBlockSize  = Problem::kBlockSize;
        constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
        constexpr index_t kKPerBlock  = Problem::BlockFmhaShape::kK2;

        constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType); // 16-byte cap

        // elements of the tile owned by each thread
        constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;

        return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
    }
}
template
<
typename
Problem
>
...
...
@@ -84,20 +241,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Vector (element) width for global loads of dO (OGrad).
// Defect fixed: removed-diff one-liner (`return 16 / sizeof(OGradDataType);`)
// was concatenated with the added implementation (unreachable return +
// duplicate `using OGradDataType`). Keeps the added implementation only.
//
// Same scheme as GetAlignmentQ, over the kM0 x kK2 dO tile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad()
{
    using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;

    constexpr index_t kBlockSize  = Problem::kBlockSize;
    constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
    constexpr index_t kKPerBlock  = Problem::BlockFmhaShape::kK2;

    constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType); // 16-byte cap
    constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);  // 4-byte floor

    // elements of the tile owned by each thread
    constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;

    constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                     ? kMaxVecLoad
                                     : (total_pixels / kMinVecLoad);
    return kVecLoad;
}
// Defect fixed: the scraped diff fused two signatures ("GetAlignment QGrad"
// immediately followed by "GetAlignment Bias") and concatenated both bodies —
// invalid C++. Reconstructed as the two separate functions, preserving both
// bodies verbatim.

// Vector width for storing dQ: taken from the innermost Y-dim length of the
// dS@K^T warp gemm's C-warp distribution, so stores match the accumulator
// register layout.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad()
{
    using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WG        = remove_cvref_t<decltype(config.template at<0>())>;
    using CWarpDstr = typename WG::CWarpDstr;
    constexpr auto vec =
        CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
    return vec;
}

// Vector (element) width for global loads of the bias tile: capped at
// 16 bytes, reduced toward a 4-byte floor when each thread owns too few
// elements of the kM0 x kN0 tile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{
    using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;

    constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
    constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);

    // elements of the tile owned by each thread
    constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;

    constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                     ? kMaxVecLoad
                                     : (total_pixels / kMinVecLoad);
    return kVecLoad;
}
template
<
typename
Problem
>
...
...
@@ -128,60 +304,35 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Vector width for the transposed (shuffled) Q^T access: the complement of
// the global-load width, so load-width * transposed-width covers each
// thread's share of the kM0 x kK0 tile.
// Defect fixed: the scraped diff concatenated the removed heuristic body
// (kQKHeaddim-based, `total_pixels > 4 ? 4 : 2`, marked "TODO: not correct!")
// with the added implementation. Keeps the added implementation; the template
// header (hidden behind the hunk boundary) is restored.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ()
{
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

    // elements of the tile owned by each thread
    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;

    return total_pixels / GetAlignmentQ<Problem>();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentK
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
return
total_pixels
/
GetAlignmentK
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentOGrad
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
return
total_pixels
/
GetAlignmentOGrad
<
Problem
>
();
}
template
<
typename
Problem
>
...
...
@@ -193,1136 +344,1577 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
32
)
return
8
;
else
return
4
;
return
total_pixels
/
GetAlignmentBias
<
Problem
>
();
}
// these are for lds
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
SmemKPackQ
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
AlignmentPostQGradAcc
()
{
// TODO: this is for 3d layout
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
16
/
sizeof
(
QDataType
);
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
return
16
/
sizeof
(
AccDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
SmemKPackK
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
AlignmentPostQGrad
()
{
// TODO: this is for 3d layout
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
16
/
sizeof
(
KDataType
);
return
GetAlignmentPostQGradAcc
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackV
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKDramTileDistribution
()
{
// TODO: this is for 3d layout
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
return
16
/
sizeof
(
VDataType
);
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBias
()
{
// TODO: this is for 3d layout
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
return
16
/
sizeof
(
BiasDataType
);
}
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackOGrad
()
{
// TODO: this is for 3d layout
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
return
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N1
=
get_warp_size
()
/
K0
;
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N2
=
kNPerBlock
/
(
N1
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackSGrad
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVDramTileDistribution
()
{
// TODO: this is for 3d layout
using
GemmDataType
=
remove_cvref_t
<
typename
Problem
::
GemmDataType
>
;
return
16
/
sizeof
(
GemmDataType
);
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVInRegDramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WG
::
kK
;
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
auto
v_block_outer
_dstr
_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
N
IterPerWarp
,
NWarp
>
,
sequence
<
K
IterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
return
make_static_tile
_d
i
str
ibution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N
0
,
N1
,
N2
>
,
sequence
<
K
0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
sequence
<
0
,
1
>>
{});
}
// 3d + padding
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptor
()
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
+
1
)
*
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
auto
x_lds_block_desc
=
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
MNPerBlock
),
make_merge_transform
(
make_tuple
(
KPerBlock
/
KPack
,
KPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
return
x_lds_block_desc
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M1
=
get_warp_size
()
/
K0
;
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
kMPerBlock
/
(
M1
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
// 3d + padding
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptorAsXT
()
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradDramTileDistribution
()
{
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
+
1
)
*
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
auto
xt_lds_block_desc
=
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
MNPerBlock
),
make_merge_transform
(
make_tuple
(
KPerBlock
/
KPack
,
KPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
return
xt_lds_block_desc
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M1
=
get_warp_size
()
/
K0
;
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
kMPerBlock
/
(
M1
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
,
index_t
PixelsPerRow
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
XTLdsBlockDescriptor
()
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
LSEDDramTileDistribution
()
{
static_assert
(
PixelsPerRow
%
KPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
KPack
;
static_assert
(
MNPerBlock
%
NPerRow
==
0
);
static_assert
(
KPerBlock
%
KPack
==
0
);
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
auto
xt_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
KPack
)
>
{},
number
<
PixelsPerRow
+
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
auto
xt_lds_block_desc
=
transform_tensor_descriptor
(
xt_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
MNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
// Duplicate dimension
constexpr
index_t
N0
=
NWarp
;
constexpr
index_t
N1
=
(
get_warp_size
()
/
kMPerBlock
)
>
1
?
(
get_warp_size
()
/
kMPerBlock
)
:
1
;
return
xt_lds_block_desc
;
constexpr
index_t
M0
=
MWarp
;
constexpr
index_t
M1
=
(
get_warp_size
()
/
kMPerBlock
)
>
1
?
kMPerBlock
:
get_warp_size
();
constexpr
index_t
M2
=
(
get_warp_size
()
/
kMPerBlock
)
>
1
?
1
:
(
kMPerBlock
/
get_warp_size
());
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
N0
,
N1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
sequence
<
1
>
,
sequence
<
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
QLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
BiasTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
constexpr
index_t
N1
=
GetAlignmentBias
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
M1
=
get_warp_size
()
/
N0
;
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
kMPerBlock
/
(
M1
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptorAsQT
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
template
<
typename
DataType
,
index_t
MPerBlock
,
index_t
KPerBlock
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePreXDramTileDistribution
()
{
constexpr
index_t
K1
=
16
/
sizeof
(
DataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
1
;
constexpr
index_t
M1
=
get_warp_size
();
constexpr
index_t
M0
=
MPerBlock
/
M1
;
return
MakeXLdsBlockDescriptorAsXT
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
2
,
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreODramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
return
MakePreXDramTileDistribution
<
ODataType
,
kBlockSize
,
kKPerBlock
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KLdsBlockDescriptorAsKT
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreOGradDramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
return
Make
XLdsBlockDescriptorAsXT
<
kNPer
Block
,
kKPerBlock
,
kKPack
>
();
return
Make
PreXDramTileDistribution
<
OGradDataType
,
k
Block
Size
,
kKPerBlock
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
VLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PostQGradAccDramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kQKHeaddim
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
K1
=
16
/
sizeof
(
AccDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
M2
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
1
>
,
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
3
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
OGradLdsBlockDescriptorAsOGradT
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PostQGradDramTileDistribution
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
return
MakeXLdsBlockDescriptorAsXT
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kQKHeaddim
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSGradLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPack
=
GetSmemKPackSGrad
<
Problem
>
();
constexpr
index_t
K1
=
16
/
sizeof
(
AccDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
M2
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// these are for lds
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQTLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
QDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
return
GetAlignmentQ
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKTLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQT
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
KDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
return
GetTransposedAlignmentQ
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradTLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackK
()
{
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
OGradDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
return
GetAlignmentK
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasTLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackKT
()
{
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
BiasDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kMPerBlock
%
kKPack
==
0
);
constexpr
auto
biast_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
)
>
{},
number
<
PixelsPerRow
+
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
biast_lds_block_desc
=
transform_tensor_descriptor
(
biast_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kMPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
biast_lds_block_desc
;
return
GetTransposedAlignmentK
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmem
SizeQ
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmem
KPackV
()
{
constexpr
index_t
smem_size_q
=
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_q
;
return
GetAlignmentV
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQT
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBias
()
{
constexpr
index_t
smem_size_qt
=
[
&
]()
{
if
constexpr
(
QLoadOnce
&&
!
QTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_qt
;
return
GetAlignmentBias
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeK
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBiasT
()
{
constexpr
index_t
smem_size_k
=
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_k
;
return
GetTransposedAlignmentBias
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeKT
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackOGrad
()
{
constexpr
index_t
smem_size_kt
=
[
&
]()
{
if
constexpr
(
KLoadOnce
&&
!
KTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_kt
;
return
GetAlignmentOGrad
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeV
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackOGradT
()
{
constexpr
index_t
smem_size_v
=
[
&
]()
{
if
constexpr
(
VLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
VDataType
)
*
MakeVLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_v
;
return
GetTransposedAlignmentOGrad
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmem
SizeO
Grad
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmem
KPackS
Grad
()
{
constexpr
index_t
smem_size_do
=
sizeof
(
typename
Problem
::
OGradDataType
)
*
MakeOGradLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_do
;
// TODO: this is for 3d layout
using
GemmDataType
=
remove_cvref_t
<
typename
Problem
::
GemmDataType
>
;
return
16
/
sizeof
(
GemmDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeOGradT
()
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptor
()
{
constexpr
index_t
smem_size_dot
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
&&
!
OGradTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
OGradDataType
)
*
MakeOGradTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_dot
;
constexpr
auto
DataTypeSize
=
2
;
// sizeof(F16/BF16)
constexpr
auto
MNLdsLayer
=
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
*
MNLdsLayer
>
{},
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
MNLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
x_lds_block_desc_permuted
=
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
KPerBlock
/
KPack
*
MNLdsLayer
>
{})),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
x_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
x_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNLdsLayer
>
{})),
make_pass_through_transform
(
number
<
MNPerBlock
/
MNLdsLayer
>
{}),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
x_lds_block_desc
=
transform_tensor_descriptor
(
x_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
MNLdsLayer
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
x_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeSGrad
()
template
<
typename
Problem
,
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
,
index_t
KPackT
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXTLdsBlockDescriptor
()
{
constexpr
index_t
smem_size_ds
=
sizeof
(
typename
Problem
::
GemmDataType
)
*
MakeSGradLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_ds
;
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr
auto
MNPerXDL
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
constexpr
auto
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
auto
MN0
=
MNPerBlock
/
KPack
;
constexpr
auto
MN1
=
KPack
;
constexpr
auto
KThreadWrite
=
kBlockSize
/
MN0
;
constexpr
auto
K0Number
=
KPerBlock
/
KPackT
;
constexpr
auto
K0PerThreadWrite
=
K0Number
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
get_warp_size
()
/
MNPerXDL
;
// assume 32x32x8 mfma
constexpr
auto
K0PerThreadRead
=
K0Number
/
KThreadRead
;
constexpr
auto
kfold
=
(
KPackT
*
MN0
*
2
>
128
)
?
1
:
128
/
(
KPackT
*
MN0
*
2
);
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=mnpair<=n0
constexpr
auto
mnpair
=
(
KPackT
*
MNPerXDL
*
2
>
128
)
?
1
:
((
128
/
(
KPackT
*
MNPerXDL
*
2
))
>
MN0
?
MN0
:
128
/
(
KPackT
*
MNPerXDL
*
2
));
constexpr
auto
xt_lds_block_desc_raw
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
MN1
>
{},
number
<
kfold
*
MN0
/
mnpair
>
{},
number
<
mnpair
>
{},
KPackT
),
make_tuple
(
number
<
KPackT
*
kfold
*
MN0
*
KThreadReadPerm
*
MN1
*
K0PerThreadWrite
>
{},
number
<
KPackT
*
kfold
*
MN0
*
KThreadReadPerm
*
MN1
>
{},
number
<
KPackT
*
kfold
*
MN0
>
{},
number
<
KPackT
*
mnpair
>
{},
number
<
KPackT
>
{},
number
<
1
>
{}),
number
<
KPackT
>
{},
number
<
1
>
{});
constexpr
auto
xt_lds_block_desc_permuted
=
transform_tensor_descriptor
(
xt_lds_block_desc_raw
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
MN1
>
{},
number
<
kfold
*
MN0
/
mnpair
>
{})),
make_pass_through_transform
(
number
<
mnpair
>
{}),
make_pass_through_transform
(
KPackT
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
xt_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
xt_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
MN1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
MN0
/
mnpair
>
{})),
make_pass_through_transform
(
number
<
mnpair
>
{}),
make_pass_through_transform
(
KPackT
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
xt_lds_block_desc
=
transform_tensor_descriptor
(
xt_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KPackT
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MN0
/
mnpair
>
{},
number
<
mnpair
>
{},
number
<
MN1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
xt_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeBias
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsWriteBlockDescriptor
()
{
constexpr
index_t
smem_size_bias
=
[
&
]()
{
if
constexpr
(
Problem
::
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
sizeof
(
typename
Problem
::
BiasDataType
)
*
MakeBiasTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
else
return
0
;
}();
return
smem_size_bias
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_q
=
GetSmemSizeQ
<
Problem
>
();
constexpr
index_t
smem_size_qt
=
GetSmemSizeQT
<
Problem
>
();
constexpr
index_t
smem_size_k
=
GetSmemSizeK
<
Problem
>
();
constexpr
index_t
smem_size_kt
=
GetSmemSizeKT
<
Problem
>
();
constexpr
index_t
smem_size_v
=
GetSmemSizeV
<
Problem
>
();
constexpr
index_t
smem_size_do
=
GetSmemSizeOGrad
<
Problem
>
();
constexpr
index_t
smem_size_dot
=
GetSmemSizeOGradT
<
Problem
>
();
constexpr
index_t
smem_size_ds
=
GetSmemSizeSGrad
<
Problem
>
();
constexpr
index_t
smem_size_bias
=
GetSmemSizeBias
<
Problem
>
();
constexpr
index_t
smem_size_transpose
=
max
(
smem_size_ds
,
smem_size_bias
);
index_t
smem_size
=
0
;
if
constexpr
(
QLoadOnce
&&
OGradLoadOnce
)
smem_size
+=
smem_size_q
+
smem_size_qt
+
smem_size_do
+
smem_size_dot
+
smem_size_transpose
;
// 1~4 & 10
else
if
(
QLoadOnce
&&
!
OGradLoadOnce
&&
!
OGradTLoadOnce
)
smem_size
+=
smem_size_q
+
smem_size_qt
+
max
(
smem_size_do
,
smem_size_dot
,
smem_size_transpose
);
// 5/7/11 TODO: Multiple buffers strategy
else
if
(
!
QLoadOnce
&&
!
QTLoadOnce
&&
OGradLoadOnce
)
smem_size
+=
smem_size_do
+
smem_size_dot
+
max
(
smem_size_q
,
smem_size_qt
,
smem_size_transpose
);
// 6/8/12 TODO: Multiple buffers strategy
else
if
(
!
QLoadOnce
&&
!
QTLoadOnce
&&
!
OGradLoadOnce
&&
!
OGradTLoadOnce
)
smem_size
+=
max
(
smem_size_q
,
smem_size_qt
,
smem_size_do
,
smem_size_dot
,
smem_size_transpose
);
// 9/13 TODO: Multiple buffers strategy
// 14/15 needs to be adjusted
if
constexpr
(
KLoadOnce
)
smem_size
+=
(
smem_size_k
+
smem_size_kt
);
// 1~13
else
smem_size
=
max
(
smem_size_k
,
smem_size_kt
,
smem_size
);
// 14/15 TODO: Multiple buffers strategy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
max
(
smem_size
,
smem_size_v
);
// 15
}
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEDDramTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
index_t
N1
=
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
;
constexpr
index_t
N0
=
NWarp
;
constexpr
auto
k_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
index_t
M4
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
*
2
;
constexpr
index_t
M3
=
WG
::
WarpGemmAttribute
::
Impl
::
kCMLane
;
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM0PerLane
/
2
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
WG
::
WarpGemmAttribute
::
Impl
::
kM
);
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
N0
,
N1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
,
M4
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
3
,
1
>>
,
sequence
<
1
,
1
,
1
>
,
sequence
<
0
,
2
,
4
>>
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
VDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KRegBlockDescriptor
()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
return
make_static_tile
_d
i
str
ibution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N
0
,
N1
,
N2
>
,
sequence
<
K
0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
auto
k_block_outer
_dstr
_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
N
IterPerWarp
,
NWarp
>
,
sequence
<
K
IterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
QDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
VLdsWriteBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
constexpr
index_t
kVPack
=
GetSmemKPackV
<
Problem
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kVPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
VRegSliceBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledKRegWriteBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
N2
=
GetTransposedAlignmentK
<
Problem
>
();
constexpr
index_t
N1
=
get_warp_size
()
/
K0
;
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
2
,
1
>
,
sequence
<
1
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledKLdsWriteBlockDescriptor
()
{
// Hold all data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
index_t
kKPackT
=
GetSmemKPackKT
<
Problem
>
();
return
MakeXTLdsBlockDescriptor
<
Problem
,
kNPerBlock
,
kKPerBlock
,
kKPack
,
kKPackT
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKTLdsReadBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
auto
shuffled_k_lds_block_desc
=
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>
();
return
transform_tensor_descriptor
(
shuffled_k_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_pass_through_transform
(
number
<
kKPerBlock
>
{})),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKTRegBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradKTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
kt_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
kt_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
kt_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
kt_block_dstr
=
make_static_tile_distribution
(
kt_block_dstr_encode
);
return
kt_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
OGradDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
QLdsBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
return
make_static_tile
_d
i
str
ibution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M
0
,
M1
,
M2
>
,
sequence
<
K
0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
auto
q_block_outer
_dstr
_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
M
IterPerWarp
,
MWarp
>
,
sequence
<
K
IterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
q_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
q_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
q_block_dstr
=
make_static_tile_distribution
(
q_block_dstr_encode
);
return
q_block_dstr
;
}
template
<
typename
DataType
,
index_t
MPerBlock
,
index_t
KPerBlock
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreXDramTileDistribution
()
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
ShuffledQRegWriteBlockDescriptor
()
{
constexpr
index_t
K1
=
16
/
sizeof
(
DataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
1
;
constexpr
index_t
M1
=
get_warp_size
();
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
GetTransposedAlignmentQ
<
Problem
>
();
constexpr
index_t
N1
=
get_warp_size
()
/
K0
;
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M
0
,
M
1
,
M
2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
2
,
0
,
1
>>
{});
tuple
<
sequence
<
N
0
,
N
1
,
N
2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
2
,
1
>
,
sequence
<
1
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreODramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
ShuffledQLdsWriteBlockDescriptor
()
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
// Hold full block data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
k
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKP
erBlo
ck
=
Problem
::
kVHeaddim
;
constexpr
index_t
k
KPack
=
GetSmemKPackQ
<
Problem
>
()
;
constexpr
index_t
kKP
a
ck
T
=
GetSmemKPackQT
<
Problem
>
()
;
return
Make
PreXDramTileDistribution
<
ODataType
,
k
Block
Size
,
kKPerBlock
>
();
return
Make
XTLdsBlockDescriptor
<
Problem
,
kNPer
Block
,
kKPerBlock
,
kKPack
,
kKPackT
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreOGradDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
QTLdsReadBlockDescriptor
()
{
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
// Hold full block data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
auto
shuffled_q_lds_block_desc
=
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>
();
return
MakePreXDramTileDistribution
<
OGradDataType
,
kBlockSize
,
kKPerBlock
>
();
return
transform_tensor_descriptor
(
shuffled_q_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_pass_through_transform
(
number
<
kKPerBlock
>
{})),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeQT
DramTileDistribution
()
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
MakeQT
RegSliceBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradTQTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK3
;
constexpr
index_t
N
1
=
GetTransposedAlignmentQ
<
Problem
>
(
);
constexpr
index_t
N0
=
k
N
PerBlock
/
N1
;
// P
constexpr
index_t
N
IterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
k
K
PerBlock
/
WarpGemm
::
kK
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
constexpr
auto
qt_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
constexpr
auto
qt_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
qt_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
qt_block_dstr
=
make_static_tile_distribution
(
qt_block_dstr_encode
);
return
qt_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeS
huffledQTReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeS
GradTRegSlice
BlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradTQTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
N1
=
GetTransposedAlignmentQ
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
1
>
{});
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK3
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
dst_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
dst_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
dst_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
dst_block_dstr
=
make_static_tile_distribution
(
dst_block_dstr_encode
);
return
dst_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
Make
KTDramTileDistribution
()
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
Make
LSEDLdsWriteBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
using
LSEDType
=
remove_cvref_t
<
typename
Problem
::
DDataType
>
;
constexpr
index_t
kMPack
=
16
/
sizeof
(
LSEDType
);
constexpr
index_t
N1
=
GetTransposedAlignmentK
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
auto
lsed_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMPerBlock
>
{}),
make_tuple
(
number
<
1
>
{}),
number
<
kMPack
>
{},
number
<
1
>
{});
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
lsed_lds_block_desc
;
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEDLdsReadBlockDescriptor
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
N1
=
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
;
constexpr
index_t
N0
=
NWarp
;
// M4 *2 and M2 /2 when swizzle mode enabled
constexpr
index_t
SwizzleConfig
=
WG
::
kM
==
16
?
1
:
2
;
// constexpr index_t SwizzleConfig = 1;
constexpr
index_t
M4
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
*
SwizzleConfig
;
constexpr
index_t
M3
=
WG
::
WarpGemmAttribute
::
Impl
::
kCMLane
;
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM0PerLane
/
SwizzleConfig
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
WG
::
WarpGemmAttribute
::
Impl
::
kM
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K
0
,
K
1
,
K
2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
tile_distribution_encoding
<
sequence
<
N0
,
N1
>
,
tuple
<
sequence
<
M
0
,
M
1
,
M
2
,
M3
,
M4
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
3
,
1
>>
,
sequence
<
1
,
1
,
1
>
,
sequence
<
0
,
2
,
4
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
ShuffledKTReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
OGradLds
BlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
// Hold full block data
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
N1
=
GetTransposedAlignmentK
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
do_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
do_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
do_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
do_block_dstr
=
make_static_tile_distribution
(
do_block_dstr_encode
);
return
do_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
Make
OGradTDramTileDistribution
()
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
Make
ShuffledOGradRegWriteBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentOGrad
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
GetTransposedAlignmentOGrad
<
Problem
>
();
constexpr
index_t
N1
=
get_warp_size
()
/
K0
;
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
sequence
<
1
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledOGrad
TReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledOGrad
LdsWrite
BlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// Hold all data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
N1
=
GetTransposedAlignmentOGrad
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
kKPackT
=
GetSmemKPackOGradT
<
Problem
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
return
MakeXTLdsBlockDescriptor
<
Problem
,
kNPerBlock
,
kKPerBlock
,
kKPack
,
kKPackT
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradTLdsReadBlockDescriptor
()
{
// Hold all data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
auto
shuffled_do_lds_block_desc
=
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>
();
return
transform_tensor_descriptor
(
shuffled_do_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_pass_through_transform
(
number
<
kKPerBlock
>
{})),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradTRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetPTOGradTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
// constexpr index_t kNPerBlock = 32;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
dot_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
dot_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
dot_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
dot_block_dstr
=
make_static_tile_distribution
(
dot_block_dstr_encode
);
return
dot_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeBiasTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePTRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetPTOGradTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
pt_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
pt_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
pt_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
pt_block_dstr
=
make_static_tile_distribution
(
pt_block_dstr_encode
);
return
pt_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSGradLdsBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPack
=
GetSmemKPackSGrad
<
Problem
>
();
constexpr
index_t
N1
=
GetTransposedAlignmentBias
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
M3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
static_assert
(
kKPack
%
M3
==
0
);
constexpr
index_t
M2
=
kKPack
/
M3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
M1
=
get_warp_size
()
/
(
M2
*
N0
);
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kMPerBlock
==
M0
*
M1
*
M2
*
M3
);
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSGradRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradKTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK4
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
ds_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
3
,
1
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
ds_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
ds_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
ds_block_dstr
=
make_static_tile_distribution
(
ds_block_dstr_encode
);
return
ds_block_dstr
;
}
// Re-layout the P^T tile produced by Gemm0 (C-register distribution) into the
// A-register distribution expected by Gemm1 (P^T x OGrad^T).
//
// When the Gemm1 warp tile M-dimension is 16 the per-thread element order of
// the source C distribution and the destination A distribution differ, so each
// (kIter, mIter) warp sub-tile is copied through a scratch warp tensor while
// swapping the iteration-index order (kIter,mIter) -> (mIter,kIter).
// Otherwise the thread buffers are bit-identical and a plain buffer copy suffices.
//
// @param pt_out destination tile (Gemm1 A-operand distribution), written in place
// @param p_in   source tile (Gemm0 C-operand distribution), read only
template
<
typename
Problem
,
typename
PTOutTensor
,
typename
PInTensor
>
CK_TILE_DEVICE
static
constexpr
void
PTFromGemm0CToGemm1A
(
PTOutTensor
&
pt_out
,
const
PInTensor
&
p_in
)
{
// Layout shuffle only needed for the 16-wide warp-gemm M tile; for other
// sizes the distributions coincide (see else-branch direct copy below).
if
constexpr
(
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{})
==
16
)
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetPTOGradTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
0
>
{});
// P^T block is kN0 rows by kK1 columns from Gemm1's point of view.
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
// Per-warp repeat counts along M and K.
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
using
AWarpDstr
=
typename
WarpGemm
::
AWarpDstr
;
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
// Scratch warp-sized tensor used to stage one sub-tile per iteration.
auto
pt_warp_tensor
=
make_static_distributed_tensor
<
typename
Problem
::
GemmDataType
>
(
CWarpDstr
{});
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// Copy each warp sub-tile: read with (kIter, mIter) ordering from the C
// layout, write with (mIter, kIter) ordering into the A layout.
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
pt_warp_tensor
.
get_thread_buffer
()
=
p_in
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
kIter
,
mIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
pt_out
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
),
pt_warp_tensor
.
get_thread_buffer
());
});
});
}
else
{
// Distributions match element-for-element: straight buffer copy.
pt_out
.
get_thread_buffer
()
=
p_in
.
get_thread_buffer
();
}
}
// Re-layout the SGrad^T tile produced by Gemm2 (C-register distribution) into
// the A-register distribution expected by Gemm3 (SGrad^T x Q^T).
//
// Mirrors PTFromGemm0CToGemm1A: when the Gemm3 warp tile M-dimension is 16 the
// per-thread element order differs between the two distributions, so each
// (kIter, mIter) warp sub-tile is staged through a scratch warp tensor and the
// iteration-index order is swapped to (mIter, kIter). Otherwise a direct
// thread-buffer copy is sufficient.
//
// @param dst_out destination tile (Gemm3 A-operand distribution), written in place
// @param ds_in   source tile (Gemm2 C-operand distribution), read only
template
<
typename
Problem
,
typename
SGradTOutTensor
,
typename
SGradInTensor
>
CK_TILE_DEVICE
static
constexpr
void
SGradTFromGemm2CToGemm3A
(
SGradTOutTensor
&
dst_out
,
const
SGradInTensor
&
ds_in
)
{
// Shuffle only required for the 16-wide warp-gemm M tile.
if
constexpr
(
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
0
>
{})
==
16
)
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradTQTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
0
>
{});
// SGrad^T block is kN0 rows by kK3 columns from Gemm3's point of view.
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK3
;
// Per-warp repeat counts along M and K.
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
using
AWarpDstr
=
typename
WarpGemm
::
AWarpDstr
;
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
// Scratch warp-sized tensor used to stage one sub-tile per iteration.
auto
dst_warp_tensor
=
make_static_distributed_tensor
<
typename
Problem
::
GemmDataType
>
(
CWarpDstr
{});
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// Copy each warp sub-tile, swapping (kIter, mIter) -> (mIter, kIter).
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
dst_warp_tensor
.
get_thread_buffer
()
=
ds_in
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
kIter
,
mIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
dst_out
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
),
dst_warp_tensor
.
get_thread_buffer
());
});
});
}
else
{
// Distributions match element-for-element: straight buffer copy.
dst_out
.
get_thread_buffer
()
=
ds_in
.
get_thread_buffer
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBiasTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
N1
=
GetTransposedAlignmentBias
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
M3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
static_assert
(
kKPack
%
M3
==
0
);
constexpr
index_t
M2
=
kKPack
/
M3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
M1
=
get_warp_size
()
/
(
M2
*
N0
);
constexpr
index_t
N1
=
GetAlignmentBias
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
M2
=
GetTransposedAlignmentBias
<
Problem
>
();
constexpr
index_t
M1
=
get_warp_size
()
/
N0
;
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
2
,
1
>
,
sequence
<
1
,
3
>>
{});
sequence
<
1
,
2
>>
{});
}
// Build the LDS block descriptor holding the full Bias tile (kN0 x kM0,
// transposed layout) using the Bias-specific K-pack factors.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor()
{
    using Shape = typename Problem::BlockFmhaShape;

    // Hold full block data
    constexpr index_t rows_per_block = Shape::kN0;
    constexpr index_t cols_per_block = Shape::kM0;

    return MakeXTLdsBlockDescriptor<Problem,
                                    rows_per_block,
                                    cols_per_block,
                                    GetSmemKPackBias<Problem>(),
                                    GetSmemKPackBiasT<Problem>()>();
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBias
T
TileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBias
S
TileDistribution
()
{
using
c_block_tensor_type
=
decltype
(
BlockGemm
{}.
MakeCBlockTile
());
return
c_block_tensor_type
::
get_tile_distribution
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeQ
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
{};
}
}();
constexpr
index_t
smem_size_q
=
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_q
;
}
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeQT
()
{
constexpr
index_t
smem_size_qt
=
sizeof
(
typename
Problem
::
QDataType
)
*
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
()
;
return
BlockGemmASmemBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{}
;
return
smem_size_qt
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeK
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
Acc
DataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK1
>>
;
constexpr
index_t
smem_size_k
=
sizeof
(
typename
Problem
::
K
DataType
)
*
MakeKLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_k
;
}
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
// Bytes of LDS needed for the transposed K tile: element size of KDataType
// times the element-space size of the K^T read descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeKT()
{
    return sizeof(typename Problem::KDataType) *
           MakeKTLdsReadBlockDescriptor<Problem>().get_element_space_size();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeLSE
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK2
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
OGradDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
VDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
OGradDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
VDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
{};
}
}();
constexpr
index_t
smem_size_lse
=
sizeof
(
typename
Problem
::
LSEDataType
)
*
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_lse
;
}
using
BlockGemmPolicy
=
BlockGemmASmemBRegCRegV1CustomPolicy
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBRegCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
// {
// using BlockGemmProblem =
// BlockGemmPipelineProblem<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// Problem::kBlockSize,
// TileGemmShape<Problem::BlockFmhaShape::kM0,
// Problem::BlockFmhaShape::kN0,
// Problem::BlockFmhaShape::kK2>>;
// constexpr auto warp_gemm = []() {
// if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
// std::is_same_v<typename Problem::VDataType, half_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
// }
// else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
// std::is_same_v<typename Problem::VDataType, bf16_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
// }
// }();
// using BlockGemmPolicy =
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// typename
// Problem::BlockFmhaShape::Gemm2BlockWarps,
// decltype(warp_gemm)>;
// return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
// }
// Bytes of LDS needed for the D tile: element size of DDataType times the
// element-space size of the shared LSE/D write descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
{
    return sizeof(typename Problem::DDataType) *
           MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeV
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
Acc
DataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK3
>>
;
constexpr
index_t
smem_size_v
=
sizeof
(
typename
Problem
::
V
DataType
)
*
MakeVLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_v
;
}
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
// Bytes of LDS needed for the OGrad (dO) tile: element size of OGradDataType
// times the element-space size of the OGrad LDS descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGrad()
{
    return sizeof(typename Problem::OGradDataType) *
           MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeOGradT
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
Acc
DataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK4
>>
;
constexpr
index_t
smem_size_dot
=
sizeof
(
typename
Problem
::
OGrad
DataType
)
*
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_dot
;
}
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
return
BlockGemmASmemBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
// Bytes of LDS needed for the SGrad (dS) tile: element size of GemmDataType
// times the element-space size of the SGrad LDS descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeSGrad()
{
    return sizeof(typename Problem::GemmDataType) *
           MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size();
}
// Bytes of LDS needed for the Bias tile. Only elementwise bias occupies LDS;
// every other bias mode contributes nothing.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeBias()
{
    if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
    {
        return sizeof(typename Problem::BiasDataType) *
               MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
    }
    else
    {
        return 0;
    }
}
// Total LDS budget for the kernel: the maximum over the three pipeline stages,
// since the stages reuse the same LDS region.
//   stage0_0: K + K^T resident together
//   stage0_1: V resident
//   stage1:   Q, Q^T, OGrad, OGrad^T, LSE, D resident, plus whichever of
//             Bias / SGrad is larger (they never coexist).
//
// Fix: removed a stray duplicated '+' in the stage1 sum
// ("smem_size_q + + smem_size_dot") — it parsed as a harmless unary plus but
// was clearly a typo.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    constexpr index_t smem_size_q    = GetSmemSizeQ<Problem>();
    constexpr index_t smem_size_qt   = GetSmemSizeQT<Problem>();
    constexpr index_t smem_size_lse  = GetSmemSizeLSE<Problem>();
    constexpr index_t smem_size_k    = GetSmemSizeK<Problem>();
    constexpr index_t smem_size_kt   = GetSmemSizeKT<Problem>();
    constexpr index_t smem_size_v    = GetSmemSizeV<Problem>();
    constexpr index_t smem_size_do   = GetSmemSizeOGrad<Problem>();
    constexpr index_t smem_size_dot  = GetSmemSizeOGradT<Problem>();
    constexpr index_t smem_size_d    = GetSmemSizeD<Problem>();
    constexpr index_t smem_size_ds   = GetSmemSizeSGrad<Problem>();
    constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();

    constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
    constexpr index_t smem_size_stage0_1 = smem_size_v;
    // Bias and SGrad share the same slot, hence max() rather than sum.
    constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
                                         smem_size_do + smem_size_lse + smem_size_d +
                                         max(smem_size_bias, smem_size_ds);

    return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
}
template
<
typename
Problem_
>
struct
HotLoopScheduler
{
using
Problem
=
Problem_
;
// Primary template: no scheduling hints for unhandled gemm stages.
// Stages 0-4 are covered by the explicit specializations below.
template
<
index_t
GemmStage
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
()
{
}
// Stage 0 scheduling hints (Gemm0: Q x K) — interleave the stage's MFMA
// instructions with its VMEM reads (Q/LSE/OGrad/D global loads) and its
// OGrad^T LDS reads via sched_group_barrier groups. Mask values per the
// comments below: 0x020 = VMEM read, 0x008 = MFMA, 0x100 = DS read.
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
0
>
()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
constexpr
index_t
VMEM_READ_INST
=
Q_VMEM_READ
+
OGrad_VMEM_READ
+
LSE_VMEM_READ
+
D_VMEM_READ
;
constexpr
index_t
LDS_READ_INST
=
OGradT_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm0MFMA
;
// Evenly distributed to relieve SQ->TA FIFO pressure
constexpr
index_t
MFMA_PER_VMEM_READ
=
MFMA_INST
/
VMEM_READ_INST
;
// MFMAs left over after the even split; issued in a tail loop below.
constexpr
index_t
MFMA_Remainder
=
MFMA_INST
-
MFMA_PER_VMEM_READ
*
VMEM_READ_INST
;
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
LDS_READ_INST
/
MFMA_INST
;
// Main pattern: one VMEM read, then MFMA_PER_VMEM_READ x (MFMA + DS reads).
static_for
<
0
,
VMEM_READ_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
static_for
<
0
,
MFMA_PER_VMEM_READ
,
1
>
{}([
&
](
auto
j
)
{
ignore
=
j
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
});
});
// Tail: remaining MFMAs with their share of DS reads.
static_for
<
0
,
MFMA_Remainder
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
});
}
// Stage 1 scheduling hints (Gemm1: OGrad x V) — pair each MFMA with its share
// of Q^T LDS reads. Masks: 0x008 = MFMA, 0x100 = DS read.
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
1
>
()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
constexpr
index_t
LDS_READ_INST
=
QT_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm1MFMA
;
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
LDS_READ_INST
/
MFMA_INST
;
// One MFMA followed by LDS_READ_PER_MFMA DS reads, repeated per MFMA.
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
});
}
// Stage 2 scheduling hints (Gemm2: P^T x OGrad) — pair each MFMA with its
// share of LDS stores for Q/Q^T/OGrad/OGrad^T/LSE/D.
// Masks: 0x008 = MFMA, 0x200 = DS write.
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
2
>
()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
constexpr
index_t
LDS_WRITE_INST
=
Q_LDS_WRITE
+
QT_LDS_WRITE
+
OGrad_LDS_WRITE
+
OGradT_LDS_WRITE
+
LSE_LDS_WRITE
+
D_LDS_WRITE
;
constexpr
index_t
MFMA_INST
=
Gemm2MFMA
;
// To hide instruction issue latency
constexpr
index_t
LDS_WRITE_PER_MFMA
=
LDS_WRITE_INST
/
MFMA_INST
;
// One MFMA followed by LDS_WRITE_PER_MFMA DS writes, repeated per MFMA.
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS write
});
}
// Stage 3 scheduling hints (Gemm3: SGrad^T x Q^T) — first interleave MFMAs
// with the SGrad^T LDS stores, then interleave the remaining MFMAs with the
// SGrad^T/Q/LSE LDS reads. Masks: 0x008 = MFMA, 0x200 = DS write,
// 0x100 = DS read.
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
3
>
()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
constexpr
index_t
LDS_WRITE_INST
=
SGradT_LDS_WRITE
;
constexpr
index_t
LDS_READ_INST
=
SGradT_LDS_READ_P1
+
Q_LDS_READ
+
LSE_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm3MFMA
;
// To hide instruction issue latency
// At least one DS write per MFMA in the write phase.
constexpr
index_t
LDS_WRITE_PER_MFMA
=
LDS_WRITE_INST
/
MFMA_INST
>=
1
?
LDS_WRITE_INST
/
MFMA_INST
:
1
;
// Number of MFMAs consumed by the write phase.
constexpr
index_t
MFMA_INST_LDS_WRITE
=
LDS_WRITE_INST
/
LDS_WRITE_PER_MFMA
;
// DS reads per MFMA in the read phase; 0 if no MFMAs remain,
// clamped to >= 1 when reads must be issued alongside remaining MFMAs.
constexpr
index_t
LDS_READ_PER_MFMA
=
(
MFMA_INST
-
MFMA_INST_LDS_WRITE
)
>
0
?
LDS_READ_INST
/
(
MFMA_INST
-
MFMA_INST_LDS_WRITE
)
>
0
?
LDS_READ_INST
/
(
MFMA_INST
-
MFMA_INST_LDS_WRITE
)
:
1
:
0
;
// Phase 1: MFMA + DS writes.
static_for
<
0
,
MFMA_INST_LDS_WRITE
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS Write
});
// Phase 2: remaining MFMAs + DS reads.
static_for
<
0
,
MFMA_INST
-
MFMA_INST_LDS_WRITE
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS Read
});
}
// Stage 4 scheduling hints (Gemm4: SGrad x K^T) — pair each MFMA with its
// share of SGrad^T/OGrad/D LDS reads. Masks: 0x008 = MFMA, 0x100 = DS read.
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
4
>
()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
constexpr
index_t
LDS_READ_INST
=
SGradT_LDS_READ_P2
+
OGrad_LDS_READ
+
D_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm4MFMA
;
// To hide instruction issue latency
// Clamp to >= 1 so reads still get issued when LDS_READ_INST < MFMA_INST.
constexpr
index_t
LDS_READ_PER_MFMA
=
LDS_READ_INST
/
MFMA_INST
>
0
?
LDS_READ_INST
/
MFMA_INST
:
1
;
// One MFMA followed by LDS_READ_PER_MFMA DS reads, repeated per MFMA.
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS Read
});
}
private:
    // Tile / dispatch geometry pulled from the problem description.
    static constexpr index_t kBlockSize = Problem::kBlockSize;
    static constexpr index_t kM0        = Problem::BlockFmhaShape::kM0;
    static constexpr index_t kN0        = Problem::BlockFmhaShape::kN0;
    static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kVHeaddim  = Problem::BlockFmhaShape::kVHeaddim;
    static constexpr index_t kK4        = Problem::BlockFmhaShape::kK4;

    // Warp-level GEMM tile shape (taken from Gemm0's warp tile).
    static constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
    static constexpr index_t WarpGemmN = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
    // K depth of one warp GEMM: 16 for 16x16 tiles, otherwise 8 (32x32 tiles).
    static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;

    // Warp layout used by GEMM 4.
    static constexpr index_t Gemm4MWarp =
        Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
    static constexpr index_t Gemm4NWarp =
        Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});

    // Compute
    // MFMA instruction counts per GEMM stage: (block GEMM volume) divided by
    // (warps per block * per-warp-GEMM volume).
    static constexpr index_t Gemm0MFMA =
        kM0 * kN0 * kQKHeaddim /
        (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
    static constexpr index_t Gemm1MFMA =
        kM0 * kN0 * kVHeaddim /
        (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
    static constexpr index_t Gemm2MFMA =
        kN0 * kVHeaddim * kM0 /
        (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
    static constexpr index_t Gemm3MFMA =
        kN0 * kQKHeaddim * kM0 /
        (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
    static constexpr index_t Gemm4MFMA =
        kM0 * kQKHeaddim * kN0 /
        (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);

    // VMEM
    // Global-memory load instruction counts (elements per thread / vector width).
    static constexpr index_t Q_VMEM_READ =
        kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
    static constexpr index_t OGrad_VMEM_READ =
        kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
    static constexpr index_t LSE_VMEM_READ = 1;
    static constexpr index_t D_VMEM_READ   = 1;

    // LDS Read
    static constexpr index_t OGradT_LDS_READ =
        kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
    static constexpr index_t QT_LDS_READ =
        kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
    // SGradT reads are split in two phases (P1 reads the first kK4 columns,
    // P2 the remaining kN0 - kK4).
    static constexpr index_t SGradT_LDS_READ_P1 =
        kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
    static constexpr index_t Q_LDS_READ =
        kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
    // LSE/D read counts depend on the warp tile: 4 rows/lane-group for 16x16, 2 for 32x32.
    static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
    static constexpr index_t SGradT_LDS_READ_P2 =
        kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
    static constexpr index_t OGrad_LDS_READ =
        kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
    static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);

    // LDS Write
    static constexpr index_t Q_LDS_WRITE =
        kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
    static constexpr index_t QT_LDS_WRITE =
        kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
    static constexpr index_t OGrad_LDS_WRITE =
        kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
    static constexpr index_t OGradT_LDS_WRITE =
        kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
    static constexpr index_t LSE_LDS_WRITE = 1;
    static constexpr index_t D_LDS_WRITE   = 1;
    static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
};
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
View file @
f6ceef78
...
...
@@ -8,9 +8,8 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum
class
BlockFmhaBwdPipelineEnum
{
KSKTSVR
=
0
,
QSKSVROGradS
,
KSVR
,
KRKTRVR_IGLP
=
0
,
KRKTRVR
,
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
View file @
f6ceef78
...
...
@@ -24,7 +24,9 @@ template <typename QDataType_,
typename
BiasGradDataType_
,
typename
BlockFmhaShape_
,
bool
kIsGroupMode_
,
bool
kIsDeterministic_
,
typename
FmhaMask_
,
typename
FmhaDropout_
,
typename
Traits_
>
struct
BlockFmhaBwdPipelineProblem
{
...
...
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using
BiasGradDataType
=
remove_cvref_t
<
BiasGradDataType_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaDropout
=
remove_cvref_t
<
FmhaDropout_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsDeterministic
=
kIsDeterministic_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
...
...
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Traits
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Traits
::
kHasDropout
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
...
...
@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
// Compile-time problem description for the FMHA-backward "convert dQ" pipeline,
// which converts the accumulated dQ (AccDataType) into the output QGradDataType.
template <typename AccDataType_,
          typename QGradDataType_,
          index_t kBlockSize_,
          index_t kM0_,
          index_t kN0_,
          index_t kQKHeaddim_,
          bool kIsGroupMode_,
          bool kIsDeterministic_,
          typename Traits_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
    using AccDataType   = remove_cvref_t<AccDataType_>;   // accumulator element type
    using QGradDataType = remove_cvref_t<QGradDataType_>; // output dQ element type
    using Traits        = remove_cvref_t<Traits_>;

    static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
                  "kBlockSize should be divisible by get_warp_size()");

    static constexpr index_t kBlockSize = kBlockSize_; // threads per workgroup
    static constexpr index_t kM0        = kM0_;        // seqlen_q tile size
    static constexpr index_t kN0        = kN0_;        // seqlen_k tile size
    static constexpr index_t kQKHeaddim = kQKHeaddim_; // head dim of Q/K
    static constexpr bool kIsGroupMode      = kIsGroupMode_;
    static constexpr bool kIsDeterministic  = kIsDeterministic_;

    // attributes from traits
    static constexpr bool kPadSeqLenQ    = Traits::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ   = Traits::kPadHeadDimQ;
    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; // occupancy hint
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
f6ceef78
...
...
@@ -231,7 +231,9 @@ struct BlockFmhaPipelineQRKSVSAsync
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto
q
=
decltype
(
load_tile
(
q_dram_window
)){};
set_tile
(
q
,
number
<
0
>
{});
// use per-dword clear to avoid scratch
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw
(
q
,
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
f6ceef78
...
...
@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
// Compile-time traits for the FMHA-backward "convert dQ" kernel:
// padding switches plus an occupancy hint.
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
          bool kPadHeadDimQ_ /* paddding for hdim_q */,
          index_t kBlockPerCu_ = 2 /* hint to occupancy */>
struct TileFmhaBwdConvertQGradTraits
{
    static constexpr bool kPadSeqLenQ    = kPadSeqLenQ_;
    static constexpr bool kPadHeadDimQ   = kPadHeadDimQ_;
    static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
f6ceef78
...
...
@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
0 → 100644
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
// Block-scope GEMM where A, B, and C are all register-resident block
// distributed tensors (no LDS traffic inside the GEMM itself).
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // C += A * B
    // Accumulates into c_block_tensor; all three tensors must carry the exact
    // block distribution derived below (checked by static_assert).
    template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensor& a_block_tensor,
                                   const BBlockTensor& b_block_tensor) const
    {
        // Element types must match the problem description exactly.
        static_assert(
            std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
                std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
                std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
            "wrong!");

        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;
        constexpr index_t KPerBlock = BlockGemmShape::kK;

        // Policy supplies the warp GEMM plus the MxN warp layout.
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        // Per-warp iteration counts along each block dimension.
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // M->N Warp
        // Outer (warp-level) distribution encodings; A replicates over NWarp,
        // B replicates over MWarp, C spans the MxN warp grid.
        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<KIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto b_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<MWarp>,
                                       tuple<sequence<NIterPerWarp, NWarp>,
                                             sequence<KIterPerWarp>>,
                                       tuple<sequence<0, 1>>,
                                       tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<NIterPerWarp, NWarp>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        // Embed the warp-GEMM's per-lane encodings into the outer encodings.
        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        // check ABC-block-distribution
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
                           remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "A distribution is wrong!");
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
                           remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "B distribution is wrong!");
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "C distribution is wrong!");

        using AWarpDstr = typename WG::AWarpDstr;
        using BWarpDstr = typename WG::BWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using BWarpTensor = typename WG::BWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto b_warp_y_lengths =
            to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        // K-outer, M, N inner: slice per-warp tiles out of the block tensors,
        // run the warp GEMM, and write the C slice back.
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A Block window
                AWarpTensor a_warp_tensor;

                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read B warp tensor from B block tensor
                    BWarpTensor b_warp_tensor;

                    b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // Construct a zero-distribution C block tile matching this GEMM's layout.
    CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<NIterPerWarp, NWarp>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);

        return c_block_tensor;
    }

    // C = A * B
    // Convenience overload: allocates a fresh C tile, accumulates, returns it.
    template <typename ABlockTensor, typename BBlockTensor>
    CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
                                   const BBlockTensor& b_block_tensor) const
    {
        auto c_block_tensor = MakeCBlockTile();
        operator()(c_block_tensor, a_block_tensor, b_block_tensor);
        return c_block_tensor;
    }
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
0 → 100644
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Custom policy for BlockGemmARegBRegCRegV1: caller pins the element types,
// the warp layout, and the exact WarpGemm implementation at compile time.
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpGemm_>
struct BlockGemmARegBRegCRegV1CustomPolicy
{
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    // BlockWarps is an (M, N, K) warp-count sequence.
    using BlockWarps = remove_cvref_t<BlockWarps_>;
    static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
    static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
    // NOTE(review): kKWarps is extracted but not returned below — presumably
    // K-direction warp split is unsupported here; confirm against the block GEMM.
    static constexpr index_t kKWarps = BlockWarps::at(number<2>{});

    using WarpGemm = remove_cvref_t<WarpGemm_>;

    // Returns (warp GEMM instance, M warps, N warps) for the block GEMM.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        return make_tuple(WarpGemm{}, kMWarps, kNWarps);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
0 → 100644
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmARegBRegCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmARegBRegCRegV1DefaultPolicy
{
    // Returns (warp GEMM instance, M warps, N warps) for the block GEMM,
    // selected from the problem's element types.
    //
    // Supported combinations:
    //   fp16 x fp16 -> fp32 : 32x32x8 MFMA, transposed-C distribution, 4x1 warps
    //   bf16 x bf16 -> fp32 : 32x32x8 MFMA, transposed-C distribution, 4x1 warps
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
        }
        else
        {
            // Fix: the original had no trailing branch, so an unsupported type
            // combination made this constexpr-auto function fall off the end,
            // yielding a confusing return-type-deduction error. The dependent
            // static_assert (sizeof of a complete type is never 0) fires only
            // when this branch is instantiated and names the real problem.
            static_assert(sizeof(typename Problem::ADataType) == 0,
                          "BlockGemmARegBRegCRegV1DefaultPolicy: unsupported "
                          "ADataType/BDataType/CDataType combination; use "
                          "BlockGemmARegBRegCRegV1CustomPolicy instead");
        }
    }
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
View file @
f6ceef78
...
...
@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
constexpr
index_t
MPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
1
>
{}];
//
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
//
KPerBlock == BlockGemmShape::kK,
//
"wrong!");
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
View file @
f6ceef78
...
...
@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
constexpr
index_t
MPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}];
//
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
//
KPerBlock == BlockGemmShape::kK,
//
"wrong!");
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment