Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
545eec16
Commit
545eec16
authored
Feb 17, 2025
by
aska-0096
Browse files
change q, do lds layout
parent
72428037
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
311 additions
and
212 deletions
+311
-212
example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp
example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp
+104
-89
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+38
-36
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+6
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+163
-85
No files found.
example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp
View file @
545eec16
...
@@ -17,41 +17,33 @@
...
@@ -17,41 +17,33 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
// Convert DQ
// Convert DQ
using
fmha_dtype_0
=
FmhaBwdFp16
;
using
fmha_dtype_0
=
FmhaBwdFp16
;
using
fmha_bwd_convert_dq_trait_0
=
using
fmha_bwd_convert_dq_trait_0
=
ck_tile
::
TileFmhaBwdConvertQGradTraits
<
false
,
false
,
2
>
;
ck_tile
::
TileFmhaBwdConvertQGradTraits
<
false
,
false
,
2
>
;
using
fmha_bwd_convert_dq_pipeline_problem_0
=
ck_tile
::
BlockFmhaBwdConvertQGradPipelineProblem
<
using
fmha_bwd_convert_dq_pipeline_problem_0
=
typename
FmhaBwdTypeConfig
<
fmha_dtype_0
>::
AccDataType
,
ck_tile
::
BlockFmhaBwdConvertQGradPipelineProblem
<
typename
FmhaBwdTypeConfig
<
fmha_dtype_0
>::
QGradDataType
,
typename
FmhaBwdTypeConfig
<
fmha_dtype_0
>::
AccDataType
,
/* BlockSize = */
256
,
typename
FmhaBwdTypeConfig
<
fmha_dtype_0
>::
QGradDataType
,
64
,
/* BlockSize = */
256
,
128
,
64
,
128
,
128
,
false
,
128
,
false
,
false
,
fmha_bwd_convert_dq_trait_0
>
;
false
,
fmha_bwd_convert_dq_trait_0
>
;
using
fmha_bwd_convert_dq_0
=
using
fmha_bwd_convert_dq_0
=
typename
ck_tile
::
BlockFmhaBwdConvertQGrad
<
fmha_bwd_convert_dq_pipeline_problem_0
>
;
typename
ck_tile
::
BlockFmhaBwdConvertQGrad
<
fmha_bwd_convert_dq_pipeline_problem_0
>
;
using
fmha_bwd_convert_dq_kernel_0
=
using
fmha_bwd_convert_dq_kernel_0
=
ck_tile
::
FmhaBwdConvertQGradKernel
<
fmha_bwd_convert_dq_0
>
;
ck_tile
::
FmhaBwdConvertQGradKernel
<
fmha_bwd_convert_dq_0
>
;
using
convert_dq_trait_0
=
fmha_bwd_convert_dq_traits_
<
128
,
using
convert_dq_trait_0
=
FmhaBwdFp16
,
fmha_bwd_convert_dq_traits_
<
128
,
FmhaBwdFp16
,
false
,
false
,
false
,
false
>
;
false
,
false
,
false
,
false
>
;
template
<
>
template
<
>
void
fmha_bwd_convert_dq_oneshot_
<
convert_dq_trait_0
>
(
const
ck_tile
::
stream_config
&
s
,
void
fmha_bwd_convert_dq_oneshot_
<
convert_dq_trait_0
>
(
const
ck_tile
::
stream_config
&
s
,
fmha_bwd_args
a
)
fmha_bwd_args
a
)
{
{
using
k_
=
fmha_bwd_convert_dq_kernel_0
;
using
k_
=
fmha_bwd_convert_dq_kernel_0
;
auto
[
kargs
,
grids
]
=
fmha_bwd_convert_dq_create_kargs_and_grids
<
k_
>
(
a
);
auto
[
kargs
,
grids
]
=
fmha_bwd_convert_dq_create_kargs_and_grids
<
k_
>
(
a
);
...
@@ -69,8 +61,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
...
@@ -69,8 +61,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
}
}
// dq_dk_dv
// dq_dk_dv
using
fmha_block_tile_0
=
ck_tile
::
using
fmha_block_tile_0
=
ck_tile
::
sequence
<
16
,
128
,
128
,
16
,
128
,
16
,
32
,
128
,
128
>
;
sequence
<
16
,
128
,
128
,
16
,
128
,
16
,
32
,
128
,
128
>
;
using
fmha_block_warps0_0
=
ck_tile
::
sequence
<
1
,
4
,
1
>
;
using
fmha_block_warps0_0
=
ck_tile
::
sequence
<
1
,
4
,
1
>
;
using
fmha_block_warps1_0
=
ck_tile
::
sequence
<
4
,
1
,
1
>
;
using
fmha_block_warps1_0
=
ck_tile
::
sequence
<
4
,
1
,
1
>
;
using
fmha_block_warps2_0
=
ck_tile
::
sequence
<
1
,
4
,
1
>
;
using
fmha_block_warps2_0
=
ck_tile
::
sequence
<
1
,
4
,
1
>
;
...
@@ -82,29 +73,29 @@ using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
...
@@ -82,29 +73,29 @@ using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// G1&G3 -> GdKV
// G1&G3 -> GdKV
// G4 -> GdQ
// G4 -> GdQ
using
fmha_bwd_shape_0
=
ck_tile
::
TileFmhaBwdShape
<
fmha_block_tile_0
,
using
fmha_bwd_shape_0
=
ck_tile
::
TileFmhaBwdShape
<
fmha_block_tile_0
,
fmha_block_warps0_0
,
fmha_block_warps0_0
,
fmha_warp_tile0_0
,
fmha_warp_tile0_0
,
fmha_block_warps1_0
,
fmha_block_warps1_0
,
fmha_warp_tile1_0
,
fmha_warp_tile1_0
,
fmha_block_warps0_0
,
fmha_block_warps0_0
,
fmha_warp_tile0_0
,
fmha_warp_tile0_0
,
fmha_block_warps1_0
,
fmha_block_warps1_0
,
fmha_warp_tile1_0
,
fmha_warp_tile1_0
,
fmha_block_warps2_0
,
fmha_block_warps2_0
,
fmha_warp_tile0_0
>
;
fmha_warp_tile0_0
>
;
using
fmha_bwd_trait_0
=
ck_tile
::
TileFmhaTraits
<
false
,
using
fmha_bwd_trait_0
=
ck_tile
::
TileFmhaTraits
<
false
,
false
,
false
,
false
,
false
,
false
,
false
,
ck_tile
::
BlockAttentionBiasEnum
::
NO_BIAS
,
ck_tile
::
BlockAttentionBiasEnum
::
NO_BIAS
,
false
,
false
,
false
,
false
,
false
,
false
,
false
,
false
,
1
>
;
1
>
;
using
fmha_mask_0
=
ck_tile
::
SimplifiedGenericAttentionMask
<
false
>
;
using
fmha_mask_0
=
ck_tile
::
SimplifiedGenericAttentionMask
<
false
>
;
using
fmha_dropout_0
=
ck_tile
::
BlockDropoutBwd
<
false
,
true
,
false
>
;
using
fmha_dropout_0
=
ck_tile
::
BlockDropoutBwd
<
false
,
true
,
false
>
;
using
fmha_bwd_pipeline_problem_0
=
ck_tile
::
BlockFmhaBwdPipelineProblem
<
using
fmha_bwd_pipeline_problem_0
=
ck_tile
::
BlockFmhaBwdPipelineProblem
<
typename
FmhaBwdTypeConfig
<
fmha_dtype_0
>::
QDataType
,
typename
FmhaBwdTypeConfig
<
fmha_dtype_0
>::
QDataType
,
...
@@ -129,7 +120,8 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
...
@@ -129,7 +120,8 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
fmha_dropout_0
,
fmha_dropout_0
,
fmha_bwd_trait_0
>
;
fmha_bwd_trait_0
>
;
using
fmha_bwd_pipeline_0
=
ck_tile
::
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
<
fmha_bwd_pipeline_problem_0
>
;
using
fmha_bwd_pipeline_0
=
ck_tile
::
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
<
fmha_bwd_pipeline_problem_0
>
;
using
fmha_bwd_dk_epilogue_0
=
ck_tile
::
Default2DEpilogue
<
using
fmha_bwd_dk_epilogue_0
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
typename
FmhaBwdTypeConfig
<
FmhaBwdFp16
>::
AccDataType
,
ck_tile
::
Default2DEpilogueProblem
<
typename
FmhaBwdTypeConfig
<
FmhaBwdFp16
>::
AccDataType
,
...
@@ -143,28 +135,25 @@ using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
...
@@ -143,28 +135,25 @@ using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
false
,
false
,
false
>>
;
false
>>
;
using
fmha_bwd_dq_dk_dv_kernel_0
=
using
fmha_bwd_dq_dk_dv_kernel_0
=
ck_tile
::
ck_tile
::
FmhaBwdDQDKDVKernel
<
fmha_bwd_pipeline_0
,
FmhaBwdDQDKDVKernel
<
fmha_bwd_pipeline_0
,
fmha_bwd_dk_epilogue_0
,
fmha_bwd_dv_epilogue_0
>
;
fmha_bwd_dk_epilogue_0
,
fmha_bwd_dv_epilogue_0
>
;
using
dq_dk_dv_trait_0
=
fmha_bwd_dq_dk_dv_traits_
<
128
,
using
dq_dk_dv_trait_0
=
fmha_bwd_dq_dk_dv_traits_
<
128
,
FmhaBwdFp16
,
FmhaBwdFp16
,
false
,
false
,
ck_tile
::
BlockFmhaBwdPipelineEnum
::
KRKTRVR_IGLP
,
ck_tile
::
BlockFmhaBwdPipelineEnum
::
KRKTRVR_IGLP
,
fmha_mask_0
,
fmha_mask_0
,
fmha_dropout_0
,
fmha_dropout_0
,
ck_tile
::
BlockAttentionBiasEnum
::
NO_BIAS
,
ck_tile
::
BlockAttentionBiasEnum
::
NO_BIAS
,
false
,
false
,
false
,
false
,
false
,
false
,
false
,
false
,
false
,
false
,
false
>
;
false
>
;
template
<
>
template
<
>
void
fmha_bwd_dq_dk_dv_oneshot_
<
dq_dk_dv_trait_0
>
(
const
ck_tile
::
stream_config
&
s
,
void
fmha_bwd_dq_dk_dv_oneshot_
<
dq_dk_dv_trait_0
>
(
const
ck_tile
::
stream_config
&
s
,
fmha_bwd_args
a
)
fmha_bwd_args
a
)
{
{
using
k_
=
fmha_bwd_dq_dk_dv_kernel_0
;
using
k_
=
fmha_bwd_dq_dk_dv_kernel_0
;
auto
[
kargs
,
grids
]
=
fmha_bwd_dq_dk_dv_create_kargs_and_grids
<
k_
>
(
a
);
auto
[
kargs
,
grids
]
=
fmha_bwd_dq_dk_dv_create_kargs_and_grids
<
k_
>
(
a
);
...
@@ -182,8 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
...
@@ -182,8 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
}
}
// dot_do_o
// dot_do_o
using
fmha_bwd_dot_do_o_trait_0
=
using
fmha_bwd_dot_do_o_trait_0
=
ck_tile
::
TileFmhaBwdOGradDotOTraits
<
false
,
false
,
2
>
;
ck_tile
::
TileFmhaBwdOGradDotOTraits
<
false
,
false
,
2
>
;
using
fmha_bwd_dot_do_o_pipeline_problem_0
=
ck_tile
::
BlockFmhaBwdOGradDotOPipelineProblem
<
using
fmha_bwd_dot_do_o_pipeline_problem_0
=
ck_tile
::
BlockFmhaBwdOGradDotOPipelineProblem
<
typename
FmhaBwdTypeConfig
<
fmha_dtype_0
>::
ODataType
,
typename
FmhaBwdTypeConfig
<
fmha_dtype_0
>::
ODataType
,
...
@@ -197,11 +185,9 @@ using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipel
...
@@ -197,11 +185,9 @@ using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipel
using
fmha_bwd_dot_do_o_0
=
using
fmha_bwd_dot_do_o_0
=
typename
ck_tile
::
BlockFmhaBwdOGradDotO
<
fmha_bwd_dot_do_o_pipeline_problem_0
>
;
typename
ck_tile
::
BlockFmhaBwdOGradDotO
<
fmha_bwd_dot_do_o_pipeline_problem_0
>
;
using
fmha_bwd_dot_do_o_kernel_0
=
using
fmha_bwd_dot_do_o_kernel_0
=
ck_tile
::
FmhaBwdOGradDotOKernel
<
fmha_bwd_dot_do_o_0
>
;
ck_tile
::
FmhaBwdOGradDotOKernel
<
fmha_bwd_dot_do_o_0
>
;
using
dot_do_o_trait_0
=
using
dot_do_o_trait_0
=
fmha_bwd_dot_do_o_traits_
<
128
,
FmhaBwdFp16
,
false
,
false
,
false
>
;
fmha_bwd_dot_do_o_traits_
<
128
,
FmhaBwdFp16
,
false
,
false
,
false
>
;
template
<
>
template
<
>
void
fmha_bwd_dot_do_o_oneshot_
<
dot_do_o_trait_0
>
(
const
ck_tile
::
stream_config
&
s
,
fmha_bwd_args
a
)
void
fmha_bwd_dot_do_o_oneshot_
<
dot_do_o_trait_0
>
(
const
ck_tile
::
stream_config
&
s
,
fmha_bwd_args
a
)
...
@@ -221,7 +207,6 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_0>()
...
@@ -221,7 +207,6 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_0>()
return
k_
::
GetName
();
return
k_
::
GetName
();
}
}
template
<
typename
T
>
template
<
typename
T
>
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
std
::
vector
<
T
>&
v
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
std
::
vector
<
T
>&
v
)
{
{
...
@@ -244,25 +229,53 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d
...
@@ -244,25 +229,53 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d
float
fmha_bwd_
(
const
ck_tile
::
stream_config
&
s
,
fmha_bwd_args
a
)
float
fmha_bwd_
(
const
ck_tile
::
stream_config
&
s
,
fmha_bwd_args
a
)
{
{
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
std
::
cout
<<
", "
<<
fmha_bwd_dot_do_o_get_name_
<
dot_do_o_trait_
>
()
<<
", "
<<
fmha_bwd_dq_dk_dv_get_name_
<
dq_dk_dv_trait_
>
()
<<
", "
<<
fmha_bwd_convert_dq_get_name_
<
convert_dq_trait_
>
()
<<
std
::
flush
;
std
::
cout
<<
", "
<<
fmha_bwd_dot_do_o_get_name_
<
dot_do_o_trait_
>
()
<<
", "
return
ck_tile
::
launch_kernel
(
s
,
<<
fmha_bwd_dq_dk_dv_get_name_
<
dq_dk_dv_trait_
>
()
<<
", "
[
=
](
const
ck_tile
::
stream_config
&
s_
){
fmha_bwd_dot_do_o_oneshot_
<
dot_do_o_trait_
>
(
s_
,
a
);
},
<<
fmha_bwd_convert_dq_get_name_
<
convert_dq_trait_
>
()
<<
std
::
flush
;
[
=
](
const
ck_tile
::
stream_config
&
s_
){
fmha_bwd_dq_dk_dv_oneshot_
<
dq_dk_dv_trait_
>
(
s_
,
a
);
},
return
ck_tile
::
launch_kernel
(
[
=
](
const
ck_tile
::
stream_config
&
s_
){
fmha_bwd_convert_dq_oneshot_
<
convert_dq_trait_
>
(
s_
,
a
);
}
s
,
);
[
=
](
const
ck_tile
::
stream_config
&
s_
)
{
fmha_bwd_dot_do_o_oneshot_
<
dot_do_o_trait_
>
(
s_
,
a
);
},
[
=
](
const
ck_tile
::
stream_config
&
s_
)
{
fmha_bwd_dq_dk_dv_oneshot_
<
dq_dk_dv_trait_
>
(
s_
,
a
);
},
[
=
](
const
ck_tile
::
stream_config
&
s_
)
{
fmha_bwd_convert_dq_oneshot_
<
convert_dq_trait_
>
(
s_
,
a
);
});
}
}
float
fmha_bwd
(
fmha_bwd_traits
t
,
fmha_bwd_args
a
,
const
ck_tile
::
stream_config
&
s
){
float
fmha_bwd
(
fmha_bwd_traits
t
,
fmha_bwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
float
r
=
-
1
;
float
r
=
-
1
;
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
&&
(
t
.
is_group_mode
==
false
)
&&
(
t
.
mask_type
==
mask_enum
::
no_mask
)
&&
(
t
.
bias_type
==
bias_enum
::
no_bias
)
&&
(
t
.
has_dbias
==
false
)
&&
(
t
.
has_dropout
==
false
)
&&
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
&&
(
t
.
is_group_mode
==
false
)
&&
(
a
.
seqlen_q
%
16
==
0
and
a
.
seqlen_q
%
64
==
0
)
&&
(
a
.
seqlen_k
%
128
==
0
)
&&
(
a
.
hdim_q
%
128
==
0
)
&&
(
a
.
hdim_v
%
128
==
0
)
&&
(
t
.
is_deterministic
==
false
))
{
(
t
.
mask_type
==
mask_enum
::
no_mask
)
&&
(
t
.
bias_type
==
bias_enum
::
no_bias
)
&&
(
t
.
has_dbias
==
false
)
&&
(
t
.
has_dropout
==
false
)
&&
(
a
.
seqlen_q
%
16
==
0
and
a
.
seqlen_q
%
64
==
0
)
&&
(
a
.
seqlen_k
%
128
==
0
)
&&
(
a
.
hdim_q
%
128
==
0
)
&&
(
a
.
hdim_v
%
128
==
0
)
&&
(
t
.
is_deterministic
==
false
))
{
using
dot_do_o_trait_
=
fmha_bwd_dot_do_o_traits_
<
128
,
FmhaBwdFp16
,
false
,
false
,
false
>
;
using
dot_do_o_trait_
=
fmha_bwd_dot_do_o_traits_
<
128
,
FmhaBwdFp16
,
false
,
false
,
false
>
;
using
dq_dk_dv_trait_
=
fmha_bwd_dq_dk_dv_traits_
<
128
,
FmhaBwdFp16
,
false
,
ck_tile
::
BlockFmhaBwdPipelineEnum
::
KRKTRVR_IGLP
,
ck_tile
::
SimplifiedGenericAttentionMask
<
false
>
,
ck_tile
::
BlockDropoutBwd
<
false
,
true
,
false
>
,
ck_tile
::
BlockAttentionBiasEnum
::
NO_BIAS
,
false
,
false
,
false
,
false
,
false
,
false
>
;
using
dq_dk_dv_trait_
=
using
convert_dq_trait_
=
fmha_bwd_convert_dq_traits_
<
128
,
FmhaBwdFp16
,
false
,
false
,
false
,
false
>
;
fmha_bwd_dq_dk_dv_traits_
<
128
,
FmhaBwdFp16
,
false
,
ck_tile
::
BlockFmhaBwdPipelineEnum
::
KRKTRVR_IGLP
,
ck_tile
::
SimplifiedGenericAttentionMask
<
false
>
,
ck_tile
::
BlockDropoutBwd
<
false
,
true
,
false
>
,
ck_tile
::
BlockAttentionBiasEnum
::
NO_BIAS
,
false
,
false
,
false
,
false
,
false
,
false
>
;
using
convert_dq_trait_
=
fmha_bwd_convert_dq_traits_
<
128
,
FmhaBwdFp16
,
false
,
false
,
false
,
false
>
;
r
=
fmha_bwd_
<
dot_do_o_trait_
,
dq_dk_dv_trait_
,
convert_dq_trait_
>
(
s
,
a
);
r
=
fmha_bwd_
<
dot_do_o_trait_
,
dq_dk_dv_trait_
,
convert_dq_trait_
>
(
s
,
a
);
return
r
;
return
r
;
}
}
else
{
else
{
assert
(
"unsupported case
\n
"
);
assert
(
"unsupported case
\n
"
);
return
r
;
return
r
;
}
}
...
@@ -806,11 +819,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -806,11 +819,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
ave_time
=
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config
);
float
ave_time
=
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config
);
// using instance:
// using instance:
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false,
// using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
// r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
// ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>,
// return r;
// ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; using
// convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false,
// false>; r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a); return r;
if
(
ave_time
<
0
)
if
(
ave_time
<
0
)
{
{
std
::
cout
<<
", not supported yet"
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", not supported yet"
<<
std
::
flush
<<
std
::
endl
;
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
545eec16
...
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
...
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
template
<
typename
ADataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args
.
stride_B
=
stride_B
;
args
.
stride_B
=
stride_B
;
args
.
stride_C
=
stride_C
;
args
.
stride_C
=
stride_C
;
float
ave_time
=
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
float
ave_time
=
ALayout
,
BLayout
,
CLayout
>
(
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_byte
=
std
::
size_t
num_byte
=
...
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" A_Layout ="
<<
ALayout
::
name
<<
" A_Layout ="
<<
ALayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
return
ave_time
;
return
ave_time
;
}
}
...
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
...
@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
a_m_k_dev_buf
,
b_k_n_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
c_m_n_dev_buf
,
M
,
M
,
N
,
N
,
K
,
K
,
stride_A
,
stride_A
,
stride_B
,
stride_B
,
stride_C
,
stride_C
,
kbatch
,
kbatch
,
n_warmup
,
n_warmup
,
n_repeat
);
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
bool
pass
=
true
;
...
@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
(
K
,
kbatch
,
max_accumulated_value
);
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
const
float
max_accumulated_value
=
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
(
K
,
kbatch
,
max_accumulated_value
);
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
545eec16
...
@@ -137,6 +137,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -137,6 +137,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
kN0
==
BiasGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
"wrong!"
);
// if (threadIdx.x == 0){
// HotLoopScheduler::print();
// }
// Block GEMM
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetPTOGradTBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetPTOGradTBlockGemm
<
Problem
>();
...
@@ -532,7 +535,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -532,7 +535,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// Hot loop
// Hot loop
while
(
i_total_loops
<
(
num_total_loop
-
1
))
while
(
i_total_loops
<
(
num_total_loop
-
1
))
{
{
// STAGE 1, Q@K Gemm0
// STAGE 1, Q@K Gemm0
d_block_tile
=
load_tile
(
d_dram_window
);
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
...
@@ -664,7 +667,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -664,7 +667,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// decltype(p_gemm)>(pt_reg_tensor, p_gemm);
// decltype(p_gemm)>(pt_reg_tensor, p_gemm);
pt_reg_tensor
.
get_thread_buffer
()
=
p_gemm
.
get_thread_buffer
();
pt_reg_tensor
.
get_thread_buffer
()
=
p_gemm
.
get_thread_buffer
();
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
545eec16
...
@@ -202,9 +202,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -202,9 +202,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
constexpr
index_t
kVecLoad
=
?
kMaxVecLoad
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
kMinVecLoad
;
:
kMinVecLoad
;
return
kVecLoad
;
return
kVecLoad
;
}
}
...
@@ -260,9 +259,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -260,9 +259,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
constexpr
index_t
kVecLoad
=
?
kMaxVecLoad
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
kMinVecLoad
;
:
kMinVecLoad
;
return
kVecLoad
;
return
kVecLoad
;
}
}
...
@@ -607,7 +605,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -607,7 +605,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
{
{
return
GetAlignmentQ
<
Problem
>
();
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
16
/
sizeof
(
QDataType
);
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -649,7 +648,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -649,7 +648,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackOGrad
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackOGrad
()
{
{
return
GetAlignmentOGrad
<
Problem
>
();
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
return
16
/
sizeof
(
OGradDataType
);
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -666,48 +666,73 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -666,48 +666,73 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return
16
/
sizeof
(
GemmDataType
);
return
16
/
sizeof
(
GemmDataType
);
}
}
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
,
bool
XorLdsLayout
=
true
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptor
()
{
{
constexpr
auto
DataTypeSize
=
2
;
// sizeof(F16/BF16)
if
constexpr
(
XorLdsLayout
)
constexpr
auto
MNLdsLayer
=
{
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
constexpr
auto
DataTypeSize
=
2
;
// sizeof(F16/BF16)
constexpr
auto
MNLdsLayer
=
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
make_tuple
(
number
<
KPerBlock
/
KPack
*
MNLdsLayer
>
{},
number
<
MNPerBlock
/
MNLdsLayer
>
{},
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
number
<
KPack
>
{}),
make_tuple
(
number
<
KPerBlock
/
KPack
*
MNLdsLayer
>
{},
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
MNLdsLayer
>
{},
number
<
1
>
{}),
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
KPack
>
{},
number
<
KPack
>
{}),
number
<
1
>
{});
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
MNLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
constexpr
auto
x_lds_block_desc_permuted
=
transform_tensor_descriptor
(
number
<
1
>
{});
x_lds_block_desc_0
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MNPerBlock
/
MNLdsLayer
>
{},
constexpr
auto
x_lds_block_desc_permuted
=
transform_tensor_descriptor
(
number
<
KPerBlock
/
KPack
*
MNLdsLayer
>
{})),
x_lds_block_desc_0
,
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MNPerBlock
/
MNLdsLayer
>
{},
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
number
<
KPerBlock
/
KPack
*
MNLdsLayer
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
constexpr
auto
x_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
x_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
constexpr
auto
x_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNLdsLayer
>
{})),
x_lds_block_desc_permuted
,
make_pass_through_transform
(
number
<
MNPerBlock
/
MNLdsLayer
>
{}),
make_tuple
(
make_unmerge_transform
(
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNLdsLayer
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_pass_through_transform
(
number
<
MNPerBlock
/
MNLdsLayer
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
constexpr
auto
x_lds_block_desc
=
transform_tensor_descriptor
(
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
x_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
constexpr
auto
x_lds_block_desc
=
transform_tensor_descriptor
(
make_tuple
(
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
MNLdsLayer
>
{})),
x_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_merge_transform_v3_division_mod
(
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
MNLdsLayer
>
{})),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_merge_transform_v3_division_mod
(
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
return
x_lds_block_desc
;
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
x_lds_block_desc
;
}
else
{
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
MNPerBlock
>
{},
number
<
KPerBlock
/
64
>
{},
number
<
64
/
KPack
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
KPerBlock
/
64
*
(
64
/
KPack
+
1
)
*
KPack
>
{},
number
<
(
64
/
KPack
+
1
)
*
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
return
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
MNPerBlock
>
{}),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KPerBlock
/
64
>
{},
number
<
64
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
,
2
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
}
}
template
<
typename
Problem
,
template
<
typename
Problem
,
...
@@ -986,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -986,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
k
KPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
KPack
=
GetSmemKPackQ
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
k
KPack
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
KPack
,
false
>
();
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -1193,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1193,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
k
KPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
KPack
=
GetSmemKPackOGrad
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
k
KPack
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
KPack
,
false
>
();
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -1681,14 +1706,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1681,14 +1706,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
MFMA_PER_VMEM_READ
=
MFMA_INST
/
VMEM_READ_INST
;
constexpr
index_t
MFMA_PER_VMEM_READ
=
MFMA_INST
/
VMEM_READ_INST
;
constexpr
index_t
MFMA_Remainder
=
MFMA_INST
-
MFMA_PER_VMEM_READ
*
VMEM_READ_INST
;
constexpr
index_t
MFMA_Remainder
=
MFMA_INST
-
MFMA_PER_VMEM_READ
*
VMEM_READ_INST
;
// To hide instruction issue latency
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
MFMA_INST
);
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
MFMA_INST
);
static_for
<
0
,
VMEM_READ_INST
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
VMEM_READ_INST
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
static_for
<
0
,
MFMA_PER_VMEM_READ
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
MFMA_PER_VMEM_READ
,
1
>
{}([
&
](
auto
j
)
{
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
*
MFMA_PER_VMEM_READ
+
j
<
LDS_READ_INST
){
if
constexpr
(
i
*
MFMA_PER_VMEM_READ
+
j
<
LDS_READ_INST
)
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
}
}
});
});
});
});
...
@@ -1709,11 +1737,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1709,11 +1737,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
MFMA_INST
=
Gemm1MFMA
;
constexpr
index_t
MFMA_INST
=
Gemm1MFMA
;
// To hide instruction issue latency
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
MFMA_INST
);
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
MFMA_INST
);
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
<
LDS_READ_INST
){
if
constexpr
(
i
<
LDS_READ_INST
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
}
}
});
});
...
@@ -1729,11 +1759,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1729,11 +1759,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
MFMA_INST
=
Gemm2MFMA
;
constexpr
index_t
MFMA_INST
=
Gemm2MFMA
;
// To hide instruction issue latency
// To hide instruction issue latency
constexpr
index_t
LDS_WRITE_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_WRITE_INST
,
MFMA_INST
);
constexpr
index_t
LDS_WRITE_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_WRITE_INST
,
MFMA_INST
);
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
<
LDS_WRITE_INST
){
if
constexpr
(
i
<
LDS_WRITE_INST
)
{
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS write
}
}
});
});
...
@@ -1749,31 +1781,43 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1749,31 +1781,43 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
MFMA_INST
=
Gemm3MFMA
;
constexpr
index_t
MFMA_INST
=
Gemm3MFMA
;
// To hide instruction issue latency
// To hide instruction issue latency
constexpr
index_t
LDS_WRITE_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_WRITE_INST
,
MFMA_INST
);
constexpr
index_t
LDS_WRITE_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_WRITE_INST
,
MFMA_INST
);
constexpr
index_t
MFMA_INST_LDS_WRITE
=
LDS_WRITE_INST
/
LDS_WRITE_PER_MFMA
;
constexpr
index_t
MFMA_INST_LDS_WRITE
=
LDS_WRITE_INST
/
LDS_WRITE_PER_MFMA
;
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
(
MFMA_INST
-
MFMA_INST_LDS_WRITE
));
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
(
MFMA_INST
-
MFMA_INST_LDS_WRITE
));
static_for
<
0
,
MFMA_INST_LDS_WRITE
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
MFMA_INST_LDS_WRITE
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
*
LDS_WRITE_PER_MFMA
<
LDS_WRITE_INST
){
if
constexpr
(
i
*
LDS_WRITE_PER_MFMA
<
LDS_WRITE_INST
)
if
constexpr
(
(
i
+
1
)
*
LDS_WRITE_PER_MFMA
>
LDS_WRITE_INST
){
{
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_INST
-
i
*
LDS_WRITE_PER_MFMA
,
0
);
// DS Write
if
constexpr
((
i
+
1
)
*
LDS_WRITE_PER_MFMA
>
LDS_WRITE_INST
)
{
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_INST
-
i
*
LDS_WRITE_PER_MFMA
,
0
);
// DS Write
}
}
else
{
else
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS Write
{
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS Write
}
}
}
}
});
});
static_for
<
0
,
MFMA_INST
-
MFMA_INST_LDS_WRITE
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
MFMA_INST
-
MFMA_INST_LDS_WRITE
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
*
LDS_READ_PER_MFMA
<
LDS_READ_INST
){
if
constexpr
(
i
*
LDS_READ_PER_MFMA
<
LDS_READ_INST
)
if
constexpr
(
(
i
+
1
)
*
LDS_READ_PER_MFMA
>
LDS_READ_INST
){
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_INST
-
i
*
LDS_READ_PER_MFMA
,
0
);
// DS Read
if
constexpr
((
i
+
1
)
*
LDS_READ_PER_MFMA
>
LDS_READ_INST
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_INST
-
i
*
LDS_READ_PER_MFMA
,
0
);
// DS Read
}
}
else
{
else
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS Read
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS Read
}
}
}
}
});
});
...
@@ -1788,21 +1832,42 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1788,21 +1832,42 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
MFMA_INST
=
Gemm4MFMA
;
constexpr
index_t
MFMA_INST
=
Gemm4MFMA
;
// To hide instruction issue latency
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
MFMA_INST
);
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
MFMA_INST
);
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
*
LDS_READ_PER_MFMA
<
LDS_READ_INST
){
if
constexpr
(
i
*
LDS_READ_PER_MFMA
<
LDS_READ_INST
)
if
constexpr
(
(
i
+
1
)
*
LDS_READ_PER_MFMA
>
LDS_READ_INST
){
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_INST
-
i
*
LDS_READ_PER_MFMA
,
0
);
// DS Read
if
constexpr
((
i
+
1
)
*
LDS_READ_PER_MFMA
>
LDS_READ_INST
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_INST
-
i
*
LDS_READ_PER_MFMA
,
0
);
// DS Read
}
}
else
{
else
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS Read
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS Read
}
}
}
}
});
});
}
}
CK_TILE_HOST_DEVICE
static
void
print
()
{
printf
(
"LDS instruction{"
);
//
printf
(
"OGradT_LDS_READ: %d, "
,
OGradT_LDS_READ
);
printf
(
"OGrad_LDS_READ: %d, "
,
OGrad_LDS_READ
);
printf
(
"QT_LDS_READ: %d, "
,
QT_LDS_READ
);
printf
(
"Q_LDS_READ: %d, "
,
Q_LDS_READ
);
printf
(
"SGradT_LDS_READ_P1: %d, "
,
SGradT_LDS_READ_P1
);
printf
(
"SGradT_LDS_READ_P2: %d, "
,
SGradT_LDS_READ_P2
);
printf
(
"LSE_LDS_READ: %d, "
,
LSE_LDS_READ
);
printf
(
"D_LDS_READ: %d, "
,
D_LDS_READ
);
printf
(
"}"
);
}
private:
private:
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
Problem
::
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kM0
=
Problem
::
BlockFmhaShape
::
kM0
;
...
@@ -1818,6 +1883,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1818,6 +1883,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static
constexpr
index_t
WarpGemmN
=
static
constexpr
index_t
WarpGemmN
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
1
>
{});
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpGemmK
=
WarpGemmM
==
16
?
16
:
8
;
static
constexpr
index_t
WarpGemmK
=
WarpGemmM
==
16
?
16
:
8
;
static
constexpr
index_t
Gemm0MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Gemm2MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Gemm4MWarp
=
static
constexpr
index_t
Gemm4MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Gemm4NWarp
=
static
constexpr
index_t
Gemm4NWarp
=
...
@@ -1847,20 +1916,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1847,20 +1916,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static
constexpr
index_t
D_VMEM_READ
=
1
;
static
constexpr
index_t
D_VMEM_READ
=
1
;
// LDS Read
// LDS Read
// 16 * 128 / 64 / 4 = 8
static
constexpr
index_t
OGradT_LDS_READ
=
static
constexpr
index_t
OGradT_LDS_READ
=
kM0
*
kVHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentOGrad
<
Problem
>
();
kM0
*
kVHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentOGrad
<
Problem
>
();
// 16 * 128 / 64 / 4 = 8
static
constexpr
index_t
QT_LDS_READ
=
static
constexpr
index_t
QT_LDS_READ
=
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
// 16 * 32 / 64 / 8 = 1
static
constexpr
index_t
SGradT_LDS_READ_P1
=
static
constexpr
index_t
SGradT_LDS_READ_P1
=
// kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
// kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
2
;
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
2
;
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kK0
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
// 16 * 128 / 64 / 8 = 4
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kK0
/
(
get_warp_size
()
*
Gemm0MWarp
)
/
GetAlignmentQ
<
Problem
>
();
// 1
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// 16 * 96 / 64 / 8 = 3
static
constexpr
index_t
SGradT_LDS_READ_P2
=
static
constexpr
index_t
SGradT_LDS_READ_P2
=
// kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
// kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
2
;
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
2
;
// 16 * 128 / 64 / 8 = 4
static
constexpr
index_t
OGrad_LDS_READ
=
static
constexpr
index_t
OGrad_LDS_READ
=
kM0
*
kK2
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
kM0
*
kK2
/
(
get_warp_size
()
*
Gemm2MWarp
)
/
GetAlignmentOGrad
<
Problem
>
();
// 1
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// LDS Write
// LDS Write
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment