"tests/pipelines/vscode:/vscode.git/clone" did not exist on "24895a1f494062d73028e31880c8848c6a674750"
Commit 545eec16 authored by aska-0096's avatar aska-0096
Browse files

change q, do lds layout

parent 72428037
...@@ -17,41 +17,33 @@ ...@@ -17,41 +17,33 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
// Convert DQ // Convert DQ
using fmha_dtype_0 = FmhaBwdFp16; using fmha_dtype_0 = FmhaBwdFp16;
using fmha_bwd_convert_dq_trait_0 = using fmha_bwd_convert_dq_trait_0 = ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
using fmha_bwd_convert_dq_pipeline_problem_0 = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
using fmha_bwd_convert_dq_pipeline_problem_0 = typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType, /* BlockSize = */ 256,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType, 64,
/* BlockSize = */ 256, 128,
64, 128,
128, false,
128, false,
false, fmha_bwd_convert_dq_trait_0>;
false,
fmha_bwd_convert_dq_trait_0>;
using fmha_bwd_convert_dq_0 = using fmha_bwd_convert_dq_0 =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>; typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
using fmha_bwd_convert_dq_kernel_0 = using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, using convert_dq_trait_0 =
FmhaBwdFp16, fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
false,
false,
false,
false>;
template <> template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s, void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a) fmha_bwd_args a)
{ {
using k_ = fmha_bwd_convert_dq_kernel_0; using k_ = fmha_bwd_convert_dq_kernel_0;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a); auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
...@@ -69,8 +61,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>() ...@@ -69,8 +61,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
} }
// dq_dk_dv // dq_dk_dv
using fmha_block_tile_0 = ck_tile:: using fmha_block_tile_0 = ck_tile::sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
...@@ -82,29 +73,29 @@ using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; ...@@ -82,29 +73,29 @@ using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// G1&G3 -> GdKV // G1&G3 -> GdKV
// G4 -> GdQ // G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0, using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0, fmha_block_warps0_0,
fmha_warp_tile0_0, fmha_warp_tile0_0,
fmha_block_warps1_0, fmha_block_warps1_0,
fmha_warp_tile1_0, fmha_warp_tile1_0,
fmha_block_warps0_0, fmha_block_warps0_0,
fmha_warp_tile0_0, fmha_warp_tile0_0,
fmha_block_warps1_0, fmha_block_warps1_0,
fmha_warp_tile1_0, fmha_warp_tile1_0,
fmha_block_warps2_0, fmha_block_warps2_0,
fmha_warp_tile0_0>; fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false, using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false, false,
false, false,
false, false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS, ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false, false,
false, false,
false, false,
false, false,
1>; 1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>; using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>; using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType, typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
...@@ -129,7 +120,8 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< ...@@ -129,7 +120,8 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
fmha_dropout_0, fmha_dropout_0,
fmha_bwd_trait_0>; fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>; using fmha_bwd_pipeline_0 =
ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType, ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
...@@ -143,28 +135,25 @@ using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< ...@@ -143,28 +135,25 @@ using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
false, false,
false>>; false>>;
using fmha_bwd_dq_dk_dv_kernel_0 = using fmha_bwd_dq_dk_dv_kernel_0 = ck_tile::
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0, FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0, fmha_bwd_dk_epilogue_0, fmha_bwd_dv_epilogue_0>;
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
FmhaBwdFp16, FmhaBwdFp16,
false, false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0, fmha_mask_0,
fmha_dropout_0, fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS, ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false, false,
false, false,
false, false,
false, false,
false, false,
false>; false>;
template <> template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
fmha_bwd_args a)
{ {
using k_ = fmha_bwd_dq_dk_dv_kernel_0; using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a); auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
...@@ -182,8 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>() ...@@ -182,8 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
} }
// dot_do_o // dot_do_o
using fmha_bwd_dot_do_o_trait_0 = using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits<false, false, 2>;
ck_tile::TileFmhaBwdOGradDotOTraits<false, false, 2>;
using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType, typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
...@@ -197,11 +185,9 @@ using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipel ...@@ -197,11 +185,9 @@ using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipel
using fmha_bwd_dot_do_o_0 = using fmha_bwd_dot_do_o_0 =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_0>; typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_0>;
using fmha_bwd_dot_do_o_kernel_0 = using fmha_bwd_dot_do_o_kernel_0 = ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
using dot_do_o_trait_0 = using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
template <> template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a) void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
...@@ -221,7 +207,6 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_0>() ...@@ -221,7 +207,6 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_0>()
return k_::GetName(); return k_::GetName();
} }
template <typename T> template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{ {
...@@ -244,25 +229,53 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d ...@@ -244,25 +229,53 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{ {
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", "
return ck_tile::launch_kernel(s, << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", "
[=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }, << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
[=](const ck_tile::stream_config& s_){ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }, return ck_tile::launch_kernel(
[=](const ck_tile::stream_config& s_){ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); } s,
); [=](const ck_tile::stream_config& s_) {
fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a);
},
[=](const ck_tile::stream_config& s_) {
fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a);
},
[=](const ck_tile::stream_config& s_) {
fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a);
});
} }
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s)
{
float r = -1; float r = -1;
if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) &&
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) &&
(t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) &&
(a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false))
{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; using dq_dk_dv_trait_ =
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; fmha_bwd_dq_dk_dv_traits_<128,
FmhaBwdFp16,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
ck_tile::SimplifiedGenericAttentionMask<false>,
ck_tile::BlockDropoutBwd<false, true, false>,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
false,
false>;
using convert_dq_trait_ =
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a); r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r; return r;
} }
else{ else
{
assert("unsupported case\n"); assert("unsupported case\n");
return r; return r;
} }
...@@ -806,11 +819,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -806,11 +819,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
// using instance: // using instance:
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; // using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; // using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false,
// using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; // ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
// r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a); // ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>,
// return r; // ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; using
// convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false,
// false>; r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a); return r;
if(ave_time < 0) if(ave_time < 0)
{ {
std::cout << ", not supported yet" << std::flush << std::endl; std::cout << ", not supported yet" << std::flush << std::endl;
......
...@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K, ...@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType, template <typename ADataType,
typename ALayout, typename BLayout, typename CLayout> typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf,
...@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B; args.stride_B = stride_B;
args.stride_C = stride_C; args.stride_C = stride_C;
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType, float ave_time =
ALayout, BLayout, CLayout>( gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = std::size_t num_byte =
...@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " C_Layout =" << CLayout::name << " B Type = " << DataTypeTraits<BDataType>::name
<< " A Type = " << DataTypeTraits<ADataType>::name << " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< " B Type = " << DataTypeTraits<BDataType>::name << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
<< " C Type = " << DataTypeTraits<CDataType>::name
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time; return ave_time;
} }
...@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc, ...@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if(!result) if(!result)
return -1; return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType; using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType; using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType; using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType; using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t N = arg_parser.get_int("n");
...@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc, ...@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero(); c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero(); c_m_n_dev_result.SetZero();
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
ALayout, BLayout, CLayout>(a_m_k_dev_buf, a_m_k_dev_buf,
b_k_n_dev_buf, b_k_n_dev_buf,
c_m_n_dev_buf, c_m_n_dev_buf,
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
stride_C, stride_C,
kbatch, kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true; bool pass = true;
...@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc, ...@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
a_m_k, b_k_n, c_m_n_host_ref); a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value = const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType> const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
(K, kbatch, max_accumulated_value); K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result, pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref, c_m_n_host_ref,
"Error: Incorrect results!", "Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<0>{}),
...@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc, ...@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value = const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType> const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
(K, kbatch, max_accumulated_value); K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result, pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref, c_m_n_gpu_ref,
"Error: Incorrect results!", "Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<0>{}),
......
...@@ -137,6 +137,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -137,6 +137,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
// if (threadIdx.x == 0){
// HotLoopScheduler::print();
// }
// Block GEMM // Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...@@ -532,7 +535,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -532,7 +535,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// Hot loop // Hot loop
while(i_total_loops < (num_total_loop - 1)) while(i_total_loops < (num_total_loop - 1))
{ {
// STAGE 1, Q@K Gemm0 // STAGE 1, Q@K Gemm0
d_block_tile = load_tile(d_dram_window); d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0}); move_tile_window(d_dram_window, {kM0});
...@@ -664,7 +667,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -664,7 +667,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// decltype(p_gemm)>(pt_reg_tensor, p_gemm); // decltype(p_gemm)>(pt_reg_tensor, p_gemm);
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer(); pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
auto qt_reg_tensor = load_tile(qt_lds_read_window); auto qt_reg_tensor = load_tile(qt_lds_read_window);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>(); HotLoopScheduler::template GemmStagedScheduler<1>();
......
...@@ -202,9 +202,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -202,9 +202,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) constexpr index_t kVecLoad =
? kMaxVecLoad ((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : kMinVecLoad;
: kMinVecLoad;
return kVecLoad; return kVecLoad;
} }
...@@ -260,9 +259,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -260,9 +259,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) constexpr index_t kVecLoad =
? kMaxVecLoad ((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : kMinVecLoad;
: kMinVecLoad;
return kVecLoad; return kVecLoad;
} }
...@@ -607,7 +605,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -607,7 +605,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{ {
return GetAlignmentQ<Problem>(); using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType);
} }
template <typename Problem> template <typename Problem>
...@@ -649,7 +648,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -649,7 +648,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad()
{ {
return GetAlignmentOGrad<Problem>(); using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
return 16 / sizeof(OGradDataType);
} }
template <typename Problem> template <typename Problem>
...@@ -666,48 +666,73 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -666,48 +666,73 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return 16 / sizeof(GemmDataType); return 16 / sizeof(GemmDataType);
} }
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack> template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, bool XorLdsLayout = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{ {
constexpr auto DataTypeSize = 2; // sizeof(F16/BF16) if constexpr(XorLdsLayout)
constexpr auto MNLdsLayer = {
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
constexpr auto MNLdsLayer =
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
number<MNPerBlock / MNLdsLayer>{}, constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
number<KPack>{}), make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}), number<MNPerBlock / MNLdsLayer>{},
number<KPack>{}, number<KPack>{}),
number<1>{}); make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
number<KPack>{},
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor( number<1>{});
x_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{}, constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
number<KPerBlock / KPack * MNLdsLayer>{})), x_lds_block_desc_0,
make_pass_through_transform(number<KPack>{})), make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
make_tuple(sequence<1, 0>{}, sequence<2>{}), number<KPerBlock / KPack * MNLdsLayer>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{})); make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( make_tuple(sequence<1, 0>{}, sequence<2>{}));
x_lds_block_desc_permuted,
make_tuple(make_unmerge_transform( constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})), x_lds_block_desc_permuted,
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}), make_tuple(make_unmerge_transform(
make_pass_through_transform(number<KPack>{})), make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
constexpr auto x_lds_block_desc = transform_tensor_descriptor( make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod( constexpr auto x_lds_block_desc = transform_tensor_descriptor(
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})), x_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_merge_transform_v3_division_mod( make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))), make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), make_merge_transform_v3_division_mod(
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
return x_lds_block_desc; make_tuple(sequence<0>{}, sequence<1>{}));
return x_lds_block_desc;
}
else
{
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<MNPerBlock>{},
number<KPerBlock / 64>{},
number<64 / KPack>{},
number<KPack>{}),
make_tuple(number<KPerBlock / 64 * (64 / KPack + 1) * KPack>{},
number<(64 / KPack + 1) * KPack>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
x_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<MNPerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(
number<KPerBlock / 64>{}, number<64 / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
} }
template <typename Problem, template <typename Problem,
...@@ -986,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -986,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>(); constexpr index_t KPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>(); return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>();
} }
template <typename Problem> template <typename Problem>
...@@ -1193,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1193,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>(); constexpr index_t KPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>(); return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>();
} }
template <typename Problem> template <typename Problem>
...@@ -1681,14 +1706,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1681,14 +1706,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST; constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST; constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
// To hide instruction issue latency // To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST); constexpr index_t LDS_READ_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
static_for<0, VMEM_READ_INST, 1>{}([&](auto i) { static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) { static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * MFMA_PER_VMEM_READ + j<LDS_READ_INST){ if constexpr(i * MFMA_PER_VMEM_READ + j < LDS_READ_INST)
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read {
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_PER_MFMA, 0); // DS read
} }
}); });
}); });
...@@ -1709,11 +1737,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1709,11 +1737,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm1MFMA; constexpr index_t MFMA_INST = Gemm1MFMA;
// To hide instruction issue latency // To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST); constexpr index_t LDS_READ_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
static_for<0, MFMA_INST, 1>{}([&](auto i) { static_for<0, MFMA_INST, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i <LDS_READ_INST){ if constexpr(i < LDS_READ_INST)
{
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
} }
}); });
...@@ -1729,11 +1759,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1729,11 +1759,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm2MFMA; constexpr index_t MFMA_INST = Gemm2MFMA;
// To hide instruction issue latency // To hide instruction issue latency
constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST); constexpr index_t LDS_WRITE_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
static_for<0, MFMA_INST, 1>{}([&](auto i) { static_for<0, MFMA_INST, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i < LDS_WRITE_INST){ if constexpr(i < LDS_WRITE_INST)
{
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
} }
}); });
...@@ -1749,31 +1781,43 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1749,31 +1781,43 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm3MFMA; constexpr index_t MFMA_INST = Gemm3MFMA;
// To hide instruction issue latency // To hide instruction issue latency
constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST); constexpr index_t LDS_WRITE_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA; constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE)); constexpr index_t LDS_READ_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE));
static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST){ if constexpr(i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST)
if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){ {
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write if constexpr((i + 1) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST)
{
__builtin_amdgcn_sched_group_barrier(
0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write
} }
else{ else
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write {
__builtin_amdgcn_sched_group_barrier(
0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
} }
} }
}); });
static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){ if constexpr(i * LDS_READ_PER_MFMA < LDS_READ_INST)
if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){ {
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read if constexpr((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST)
{
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
} }
else{ else
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read {
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_PER_MFMA, 0); // DS Read
} }
} }
}); });
...@@ -1788,21 +1832,42 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1788,21 +1832,42 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm4MFMA; constexpr index_t MFMA_INST = Gemm4MFMA;
// To hide instruction issue latency // To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST); constexpr index_t LDS_READ_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
static_for<0, MFMA_INST, 1>{}([&](auto i) { static_for<0, MFMA_INST, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){ if constexpr(i * LDS_READ_PER_MFMA < LDS_READ_INST)
if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){ {
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read if constexpr((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST)
{
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
} }
else{ else
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read {
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_PER_MFMA, 0); // DS Read
} }
} }
}); });
} }
CK_TILE_HOST_DEVICE static void print()
{
printf("LDS instruction{");
//
printf("OGradT_LDS_READ: %d, ", OGradT_LDS_READ);
printf("OGrad_LDS_READ: %d, ", OGrad_LDS_READ);
printf("QT_LDS_READ: %d, ", QT_LDS_READ);
printf("Q_LDS_READ: %d, ", Q_LDS_READ);
printf("SGradT_LDS_READ_P1: %d, ", SGradT_LDS_READ_P1);
printf("SGradT_LDS_READ_P2: %d, ", SGradT_LDS_READ_P2);
printf("LSE_LDS_READ: %d, ", LSE_LDS_READ);
printf("D_LDS_READ: %d, ", D_LDS_READ);
printf("}");
}
private: private:
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0; static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
...@@ -1818,6 +1883,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1818,6 +1883,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t WarpGemmN = static constexpr index_t WarpGemmN =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}); Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8; static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;
static constexpr index_t Gemm0MWarp =
Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
static constexpr index_t Gemm2MWarp =
Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
static constexpr index_t Gemm4MWarp = static constexpr index_t Gemm4MWarp =
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
static constexpr index_t Gemm4NWarp = static constexpr index_t Gemm4NWarp =
...@@ -1847,20 +1916,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1847,20 +1916,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t D_VMEM_READ = 1; static constexpr index_t D_VMEM_READ = 1;
// LDS Read // LDS Read
// 16 * 128 / 64 / 4 = 8
static constexpr index_t OGradT_LDS_READ = static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>(); kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
// 16 * 128 / 64 / 4 = 8
static constexpr index_t QT_LDS_READ = static constexpr index_t QT_LDS_READ =
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>(); kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
// 16 * 32 / 64 / 8 = 1
static constexpr index_t SGradT_LDS_READ_P1 = static constexpr index_t SGradT_LDS_READ_P1 =
// kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); // kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2; kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2;
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>(); // 16 * 128 / 64 / 8 = 4
static constexpr index_t Q_LDS_READ =
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetAlignmentQ<Problem>();
// 1
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// 16 * 96 / 64 / 8 = 3
static constexpr index_t SGradT_LDS_READ_P2 = static constexpr index_t SGradT_LDS_READ_P2 =
// kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); // kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2; kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2;
// 16 * 128 / 64 / 8 = 4
static constexpr index_t OGrad_LDS_READ = static constexpr index_t OGrad_LDS_READ =
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>(); kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetAlignmentOGrad<Problem>();
// 1
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write // LDS Write
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment