Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ccd2fb13
Commit
ccd2fb13
authored
Feb 13, 2025
by
aska-0096
Browse files
temp save
parent
385ac815
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
45 deletions
+51
-45
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+1
-0
example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp
example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp
+17
-17
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+13
-11
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+20
-17
No files found.
example/ck_tile/01_fmha/CMakeLists.txt
View file @
ccd2fb13
...
...
@@ -117,6 +117,7 @@ target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OP
set
(
STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS
)
list
(
APPEND STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero
)
list
(
APPEND STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1
)
list
(
APPEND STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker
)
set
(
STANDALONE_EXAMPLE_FA_BWD
"standalone_example_fa_bwd"
)
...
...
example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp
View file @
ccd2fb13
...
...
@@ -19,7 +19,7 @@
// Convert DQ
using
fmha_dtype_0
=
FmhaBwd
Bf
16
;
using
fmha_dtype_0
=
FmhaBwd
Fp
16
;
using
fmha_bwd_convert_dq_trait_0
=
ck_tile
::
TileFmhaBwdConvertQGradTraits
<
false
,
false
,
2
>
;
...
...
@@ -43,7 +43,7 @@ using fmha_bwd_convert_dq_kernel_0 =
ck_tile
::
FmhaBwdConvertQGradKernel
<
fmha_bwd_convert_dq_0
>
;
using
convert_dq_trait_0
=
fmha_bwd_convert_dq_traits_
<
128
,
FmhaBwd
Bf
16
,
FmhaBwd
Fp
16
,
false
,
false
,
false
,
...
...
@@ -132,14 +132,14 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
using
fmha_bwd_pipeline_0
=
ck_tile
::
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
<
fmha_bwd_pipeline_problem_0
>
;
using
fmha_bwd_dk_epilogue_0
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
typename
FmhaBwdTypeConfig
<
FmhaBwd
Bf
16
>::
AccDataType
,
typename
FmhaBwdTypeConfig
<
FmhaBwd
Bf
16
>::
KGradDataType
,
ck_tile
::
Default2DEpilogueProblem
<
typename
FmhaBwdTypeConfig
<
FmhaBwd
Fp
16
>::
AccDataType
,
typename
FmhaBwdTypeConfig
<
FmhaBwd
Fp
16
>::
KGradDataType
,
false
,
false
>>
;
using
fmha_bwd_dv_epilogue_0
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
typename
FmhaBwdTypeConfig
<
FmhaBwd
Bf
16
>::
AccDataType
,
typename
FmhaBwdTypeConfig
<
FmhaBwd
Bf
16
>::
VGradDataType
,
ck_tile
::
Default2DEpilogueProblem
<
typename
FmhaBwdTypeConfig
<
FmhaBwd
Fp
16
>::
AccDataType
,
typename
FmhaBwdTypeConfig
<
FmhaBwd
Fp
16
>::
VGradDataType
,
false
,
false
>>
;
...
...
@@ -149,7 +149,7 @@ using fmha_bwd_dq_dk_dv_kernel_0 =
fmha_bwd_dv_epilogue_0
>
;
using
dq_dk_dv_trait_0
=
fmha_bwd_dq_dk_dv_traits_
<
128
,
FmhaBwd
Bf
16
,
FmhaBwd
Fp
16
,
false
,
ck_tile
::
BlockFmhaBwdPipelineEnum
::
KRKTRVR_IGLP
,
fmha_mask_0
,
...
...
@@ -201,7 +201,7 @@ using fmha_bwd_dot_do_o_kernel_0 =
ck_tile
::
FmhaBwdOGradDotOKernel
<
fmha_bwd_dot_do_o_0
>
;
using
dot_do_o_trait_0
=
fmha_bwd_dot_do_o_traits_
<
128
,
FmhaBwd
Bf
16
,
false
,
false
,
false
>
;
fmha_bwd_dot_do_o_traits_
<
128
,
FmhaBwd
Fp
16
,
false
,
false
,
false
>
;
template
<
>
void
fmha_bwd_dot_do_o_oneshot_
<
dot_do_o_trait_0
>
(
const
ck_tile
::
stream_config
&
s
,
fmha_bwd_args
a
)
...
...
@@ -254,11 +254,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
float
fmha_bwd
(
fmha_bwd_traits
t
,
fmha_bwd_args
a
,
const
ck_tile
::
stream_config
&
s
){
float
r
=
-
1
;
if
(
t
.
data_type
.
compare
(
"
b
f16"
)
==
0
&&
(
t
.
is_group_mode
==
false
)
&&
(
t
.
mask_type
==
mask_enum
::
no_mask
)
&&
(
t
.
bias_type
==
bias_enum
::
no_bias
)
&&
(
t
.
has_dbias
==
false
)
&&
(
t
.
has_dropout
==
false
)
&&
if
(
t
.
data_type
.
compare
(
"f
p
16"
)
==
0
&&
(
t
.
is_group_mode
==
false
)
&&
(
t
.
mask_type
==
mask_enum
::
no_mask
)
&&
(
t
.
bias_type
==
bias_enum
::
no_bias
)
&&
(
t
.
has_dbias
==
false
)
&&
(
t
.
has_dropout
==
false
)
&&
(
a
.
seqlen_q
%
16
==
0
and
a
.
seqlen_q
%
64
==
0
)
&&
(
a
.
seqlen_k
%
128
==
0
)
&&
(
a
.
hdim_q
%
128
==
0
)
&&
(
a
.
hdim_v
%
128
==
0
)
&&
(
t
.
is_deterministic
==
false
))
{
using
dot_do_o_trait_
=
fmha_bwd_dot_do_o_traits_
<
128
,
FmhaBwd
Bf
16
,
false
,
false
,
false
>
;
using
dq_dk_dv_trait_
=
fmha_bwd_dq_dk_dv_traits_
<
128
,
FmhaBwd
Bf
16
,
false
,
ck_tile
::
BlockFmhaBwdPipelineEnum
::
KRKTRVR_IGLP
,
ck_tile
::
SimplifiedGenericAttentionMask
<
false
>
,
ck_tile
::
BlockDropoutBwd
<
false
,
true
,
false
>
,
ck_tile
::
BlockAttentionBiasEnum
::
NO_BIAS
,
false
,
false
,
false
,
false
,
false
,
false
>
;
using
convert_dq_trait_
=
fmha_bwd_convert_dq_traits_
<
128
,
FmhaBwd
Bf
16
,
false
,
false
,
false
,
false
>
;
using
dot_do_o_trait_
=
fmha_bwd_dot_do_o_traits_
<
128
,
FmhaBwd
Fp
16
,
false
,
false
,
false
>
;
using
dq_dk_dv_trait_
=
fmha_bwd_dq_dk_dv_traits_
<
128
,
FmhaBwd
Fp
16
,
false
,
ck_tile
::
BlockFmhaBwdPipelineEnum
::
KRKTRVR_IGLP
,
ck_tile
::
SimplifiedGenericAttentionMask
<
false
>
,
ck_tile
::
BlockDropoutBwd
<
false
,
true
,
false
>
,
ck_tile
::
BlockAttentionBiasEnum
::
NO_BIAS
,
false
,
false
,
false
,
false
,
false
,
false
>
;
using
convert_dq_trait_
=
fmha_bwd_convert_dq_traits_
<
128
,
FmhaBwd
Fp
16
,
false
,
false
,
false
,
false
>
;
r
=
fmha_bwd_
<
dot_do_o_trait_
,
dq_dk_dv_trait_
,
convert_dq_trait_
>
(
s
,
a
);
return
r
;
}
...
...
@@ -345,7 +345,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
}
template
<
>
auto
get_elimit
<
FmhaBwd
Bf
16
>
(
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
)
auto
get_elimit
<
FmhaBwd
Fp
16
>
(
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
)
{
double
rtol
=
1e-2
;
double
atol
=
1e-2
;
...
...
@@ -806,9 +806,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
ave_time
=
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config
);
// using instance:
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwd
Bf
16, false, false, false>;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwd
Bf
16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
// using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwd
Bf
16, false, false, false, false>;
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwd
Fp
16, false, false, false>;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwd
Fp
16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
// using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwd
Fp
16, false, false, false, false>;
// r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
// return r;
if
(
ave_time
<
0
)
...
...
@@ -1231,7 +1231,7 @@ int main(int argc, char* argv[])
}
else
if
(
data_type
==
"bf16"
)
{
return
run
<
FmhaBwd
Bf
16
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
FmhaBwd
Fp
16
>
(
arg_parser
)
?
0
:
-
2
;
}
return
-
3
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
ccd2fb13
...
...
@@ -533,20 +533,21 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
while
(
i_total_loops
<
(
num_total_loop
-
1
))
{
// STAGE 1, Q@K Gemm0
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
__builtin_amdgcn_sched_barrier
(
0
);
auto
s_acc
=
SPBlockTileType
{};
q_block_tile
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
do_block_tile
=
load_tile
(
do_dram_window
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
s_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
...
...
@@ -658,12 +659,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}();
// STAGE 3, P^T@OGrad^T Gemm1
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
p_gemm
)>(
pt_reg_tensor
,
p_gemm
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
// Policy::template PTFromGemm0CToGemm1A<Problem,
// decltype(pt_reg_tensor),
// decltype(p_gemm)>(pt_reg_tensor, p_gemm);
pt_reg_tensor
.
get_thread_buffer
()
=
p_gemm
.
get_thread_buffer
();
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
__builtin_amdgcn_sched_barrier
(
0
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
ccd2fb13
...
...
@@ -204,7 +204,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
(
total_pixels
/
kMinVecLoad
)
;
:
kMinVecLoad
;
return
kVecLoad
;
}
...
...
@@ -262,7 +262,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
(
total_pixels
/
kMinVecLoad
)
;
:
kMinVecLoad
;
return
kVecLoad
;
}
...
...
@@ -1292,7 +1292,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
// constexpr index_t kNPerBlock = 32;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
...
...
@@ -1673,7 +1672,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
constexpr
index_t
VMEM_READ_INST
=
Q_VMEM_READ
+
OGrad_VMEM_READ
+
LSE_VMEM_READ
+
D_VMEM_READ
;
// Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
Q_VMEM_READ
+
OGrad_VMEM_READ
;
constexpr
index_t
LDS_READ_INST
=
OGradT_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm0MFMA
;
...
...
@@ -1681,17 +1681,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
MFMA_PER_VMEM_READ
=
MFMA_INST
/
VMEM_READ_INST
;
constexpr
index_t
MFMA_Remainder
=
MFMA_INST
-
MFMA_PER_VMEM_READ
*
VMEM_READ_INST
;
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
LDS_READ_INST
/
MFMA_INST
;
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
MFMA_INST
)
;
static_for
<
0
,
VMEM_READ_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
static_for
<
0
,
MFMA_PER_VMEM_READ
,
1
>
{}([
&
](
auto
j
)
{
ignore
=
j
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
*
MFMA_PER_VMEM_READ
+
j
<
LDS_READ_INST
){
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
}
});
});
static_for
<
0
,
MFMA_Remainder
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
...
...
@@ -1708,12 +1709,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
MFMA_INST
=
Gemm1MFMA
;
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
LDS_READ_INST
/
MFMA_INST
;
constexpr
index_t
LDS_READ_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_READ_INST
,
MFMA_INST
)
;
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
<
LDS_READ_INST
){
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
}
});
}
...
...
@@ -1727,12 +1729,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
MFMA_INST
=
Gemm2MFMA
;
// To hide instruction issue latency
constexpr
index_t
LDS_WRITE_PER_MFMA
=
LDS_WRITE_INST
/
MFMA_INST
;
constexpr
index_t
LDS_WRITE_PER_MFMA
=
ck_tile
::
integer_divide_ceil
(
LDS_WRITE_INST
,
MFMA_INST
)
;
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
if
constexpr
(
i
<
LDS_WRITE_INST
){
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS write
}
});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment