Commit 545eec16 authored by aska-0096

change q, do lds layout

parent 72428037
@@ -17,15 +17,12 @@
#include <utility>
#include <vector>
// Convert DQ
using fmha_dtype_0 = FmhaBwdFp16;
using fmha_bwd_convert_dq_trait_0 =
ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
using fmha_bwd_convert_dq_trait_0 = ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
using fmha_bwd_convert_dq_pipeline_problem_0 =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
using fmha_bwd_convert_dq_pipeline_problem_0 = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
/* BlockSize = */ 256,
@@ -39,15 +36,10 @@ using fmha_bwd_convert_dq_pipeline_problem_0 =
using fmha_bwd_convert_dq_0 =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
using fmha_bwd_convert_dq_kernel_0 =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
FmhaBwdFp16,
false,
false,
false,
false>;
using convert_dq_trait_0 =
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
@@ -69,8 +61,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
}
// dq_dk_dv
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_tile_0 = ck_tile::sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
@@ -129,7 +120,8 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_pipeline_0 =
ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
@@ -143,10 +135,8 @@ using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using fmha_bwd_dq_dk_dv_kernel_0 = ck_tile::
FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0, fmha_bwd_dk_epilogue_0, fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
FmhaBwdFp16,
@@ -163,8 +153,7 @@ using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
false>;
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
@@ -182,8 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
}
// dot_do_o
using fmha_bwd_dot_do_o_trait_0 =
ck_tile::TileFmhaBwdOGradDotOTraits<false, false, 2>;
using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits<false, false, 2>;
using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
@@ -197,11 +185,9 @@ using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipel
using fmha_bwd_dot_do_o_0 =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_0>;
using fmha_bwd_dot_do_o_kernel_0 =
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
using fmha_bwd_dot_do_o_kernel_0 = ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
using dot_do_o_trait_0 =
fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
@@ -221,7 +207,6 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_0>()
return k_::GetName();
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
@@ -244,25 +229,53 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); },
[=](const ck_tile::stream_config& s_){ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); },
[=](const ck_tile::stream_config& s_){ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }
);
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", "
<< fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", "
<< fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
return ck_tile::launch_kernel(
s,
[=](const ck_tile::stream_config& s_) {
fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a);
},
[=](const ck_tile::stream_config& s_) {
fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a);
},
[=](const ck_tile::stream_config& s_) {
fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a);
});
}
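For orientation, the three launches above are the usual three-stage attention backward: fmha_bwd_dot_do_o precomputes the per-row term D, fmha_bwd_dq_dk_dv produces dK/dV and accumulates dQ in AccDataType, and fmha_bwd_convert_dq casts that accumulator into QGradDataType (see the convert-dq pipeline problem at the top of this file). As a reference only (the standard attention backward, not text from this commit), with softmax scale \alpha:

\[
\begin{aligned}
S &= \alpha\, Q K^\top, \qquad P = \operatorname{softmax}(S), \qquad O = P V,\\
D_i &= \textstyle\sum_j dO_{ij}\, O_{ij},\\
dV &= P^\top dO, \qquad dP = dO\, V^\top,\\
dS_{ij} &= P_{ij}\,\bigl(dP_{ij} - D_i\bigr),\\
dQ &= \alpha\, dS\, K, \qquad dK = \alpha\, dS^\top Q.
\end{aligned}
\]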
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s)
{
float r = -1;
if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) {
if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) &&
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) &&
(t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) &&
(a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false))
{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
using dq_dk_dv_trait_ =
fmha_bwd_dq_dk_dv_traits_<128,
FmhaBwdFp16,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
ck_tile::SimplifiedGenericAttentionMask<false>,
ck_tile::BlockDropoutBwd<false, true, false>,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
false,
false>;
using convert_dq_trait_ =
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}
else{
else
{
assert("unsupported case\n");
return r;
}
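This specialized path is only taken for batch-mode fp16 problems with no mask, bias, dbias, dropout, or deterministic dispatch, and with shapes satisfying the divisibility checks above; note that seqlen_q % 16 == 0 is already implied by seqlen_q % 64 == 0. A standalone restatement of the shape part of the check (the helper name is hypothetical, not part of this commit):

// Hypothetical helper, illustrative only: the shape constraints of the fp16,
// hdim-128 instance dispatched above.
inline bool fmha_bwd_fp16_hdim128_shape_supported(int seqlen_q, int seqlen_k, int hdim_q, int hdim_v)
{
    return (seqlen_q % 64 == 0) && // also covers the seqlen_q % 16 == 0 term
           (seqlen_k % 128 == 0) && (hdim_q % 128 == 0) && (hdim_v % 128 == 0);
}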
@@ -807,10 +820,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
// using instance:
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
// using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
// return r;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false,
// ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
// ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>,
// ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; using
// convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false,
// false>; r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a); return r;
if(ave_time < 0)
{
std::cout << ", not supported yet" << std::flush << std::endl;
......
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
@@ -57,8 +62,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(
float ave_time =
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name
<< " A Type = " << DataTypeTraits<ADataType>::name
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
@@ -133,8 +135,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(a_m_k_dev_buf,
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
@@ -160,8 +162,8 @@ int run_gemm_example_with_layouts(int argc,
a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
@@ -218,8 +220,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
......
@@ -137,6 +137,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// if (threadIdx.x == 0){
// HotLoopScheduler::print();
// }
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
@@ -665,6 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
auto qt_reg_tensor = load_tile(qt_lds_read_window);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
......
@@ -202,9 +202,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: kMinVecLoad;
constexpr index_t kVecLoad =
((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : kMinVecLoad;
return kVecLoad;
}
@@ -260,9 +259,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: kMinVecLoad;
constexpr index_t kVecLoad =
((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : kMinVecLoad;
return kVecLoad;
}
@@ -607,7 +605,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
return GetAlignmentQ<Problem>();
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType);
}
template <typename Problem>
@@ -649,7 +648,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad()
{
return GetAlignmentOGrad<Problem>();
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
return 16 / sizeof(OGradDataType);
}
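Both getters above now pin the LDS K-pack to one 16-byte access worth of elements instead of forwarding GetAlignmentQ / GetAlignmentOGrad; for 16-bit Q/dO this fixes the pack at 8 elements per lane. A minimal restatement of that rule (illustrative sketch, not the commit's code):

// Illustrative sketch: LDS K-pack = one 16-byte access per lane, independent of
// the global-load alignment.
#include <cstdint>

template <typename T>
constexpr int smem_kpack()
{
    return 16 / static_cast<int>(sizeof(T));
}

static_assert(smem_kpack<std::uint16_t>() == 8, "fp16/bf16: 8 elements per pack");
static_assert(smem_kpack<float>() == 4, "fp32: 4 elements per pack");

int main() { return 0; }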
template <typename Problem>
@@ -666,8 +666,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return 16 / sizeof(GemmDataType);
}
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, bool XorLdsLayout = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{
if constexpr(XorLdsLayout)
{
constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
constexpr auto MNLdsLayer =
@@ -709,6 +711,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return x_lds_block_desc;
}
else
{
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<MNPerBlock>{},
number<KPerBlock / 64>{},
number<64 / KPack>{},
number<KPack>{}),
make_tuple(number<KPerBlock / 64 * (64 / KPack + 1) * KPack>{},
number<(64 / KPack + 1) * KPack>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
x_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<MNPerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(
number<KPerBlock / 64>{}, number<64 / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
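The new XorLdsLayout = false branch gives a plain row-major LDS layout with KPack elements of padding after every 64 contiguous elements, rather than the XOR-swizzled layout of the existing branch; the padding presumably staggers the bank mapping between rows while keeping 16-byte accesses. A host-side sketch of the resulting strides for the Q tile configured earlier in this diff (assuming fp16, so KPack = 16 / sizeof(fp16) = 8, with MNPerBlock = kM0 = 16 and KPerBlock = kQKHeaddim = 128):

// Illustrative, host-side arithmetic only: strides of the padded (non-XOR) LDS
// descriptor above for a 16 x 128 fp16 tile. The tile numbers are taken from the
// block-tile configuration earlier in this diff.
#include <cstdio>

int main()
{
    constexpr int MNPerBlock = 16;
    constexpr int KPerBlock  = 128;
    constexpr int KPack      = 8; // 16 bytes / sizeof(fp16)

    constexpr int chunk_stride = (64 / KPack + 1) * KPack;      // 72 elements per 64-element chunk
    constexpr int row_stride   = KPerBlock / 64 * chunk_stride; // 144 elements per row
    constexpr int padded_bytes = MNPerBlock * row_stride * 2;   // 4608 bytes of LDS
    constexpr int packed_bytes = MNPerBlock * KPerBlock * 2;    // 4096 bytes without padding

    std::printf("row stride %d elements, tile %d bytes (unpadded %d bytes)\n",
                row_stride, padded_bytes, packed_bytes);
    return 0;
}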
template <typename Problem,
index_t MNPerBlock,
@@ -986,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t KPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>();
}
template <typename Problem>
@@ -1193,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t KPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>();
}
template <typename Problem>
@@ -1681,14 +1706,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
constexpr index_t LDS_READ_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * MFMA_PER_VMEM_READ + j<LDS_READ_INST){
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
if constexpr(i * MFMA_PER_VMEM_READ + j < LDS_READ_INST)
{
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_PER_MFMA, 0); // DS read
}
});
});
@@ -1709,11 +1737,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm1MFMA;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
constexpr index_t LDS_READ_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
static_for<0, MFMA_INST, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i <LDS_READ_INST){
if constexpr(i < LDS_READ_INST)
{
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
}
});
@@ -1729,11 +1759,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm2MFMA;
// To hide instruction issue latency
constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
constexpr index_t LDS_WRITE_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
static_for<0, MFMA_INST, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i < LDS_WRITE_INST){
if constexpr(i < LDS_WRITE_INST)
{
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
}
});
@@ -1749,31 +1781,43 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm3MFMA;
// To hide instruction issue latency
constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
constexpr index_t LDS_WRITE_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE));
constexpr index_t LDS_READ_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE));
static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST){
if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write
if constexpr(i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST)
{
if constexpr((i + 1) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST)
{
__builtin_amdgcn_sched_group_barrier(
0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write
}
else{
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
else
{
__builtin_amdgcn_sched_group_barrier(
0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
}
}
});
static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){
if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
if constexpr(i * LDS_READ_PER_MFMA < LDS_READ_INST)
{
if constexpr((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST)
{
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
}
else{
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
else
{
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_PER_MFMA, 0); // DS Read
}
}
});
@@ -1788,21 +1832,42 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm4MFMA;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
constexpr index_t LDS_READ_PER_MFMA =
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
static_for<0, MFMA_INST, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){
if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
if constexpr(i * LDS_READ_PER_MFMA < LDS_READ_INST)
{
if constexpr((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST)
{
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
}
else{
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
else
{
__builtin_amdgcn_sched_group_barrier(
0x100, LDS_READ_PER_MFMA, 0); // DS Read
}
}
});
}
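The reformatted scheduler bodies above all follow the same pattern: per iteration, one MFMA group barrier followed by a batch of DS-read (or DS-write) group barriers, with the Gemm3/Gemm4 variants additionally capping the final batch at the remaining count; the mask values match the comments in the code (0x008 MFMA, 0x020 VMEM read, 0x100 DS read, 0x200 DS write). A host-side sketch of the issue pattern the Gemm4 scheduler requests, using made-up instruction counts purely for illustration:

// Host-side illustration only: mirrors the grouping logic of the Gemm4 scheduler
// above. MFMA_INST and LDS_READ_INST are assumed example values, not the kernel's
// real counts; on device the pattern is emitted via __builtin_amdgcn_sched_group_barrier.
#include <cstdio>

int main()
{
    constexpr int MFMA_INST         = 8;  // assumed for illustration
    constexpr int LDS_READ_INST     = 12; // assumed for illustration
    constexpr int LDS_READ_PER_MFMA = (LDS_READ_INST + MFMA_INST - 1) / MFMA_INST; // integer_divide_ceil

    for(int i = 0; i < MFMA_INST; ++i)
    {
        std::printf("group %d: 1 x MFMA (0x008)", i);
        if(i * LDS_READ_PER_MFMA < LDS_READ_INST)
        {
            const int reads = ((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST)
                                  ? LDS_READ_INST - i * LDS_READ_PER_MFMA // final partial group
                                  : LDS_READ_PER_MFMA;
            std::printf(" + %d x DS read (0x100)", reads);
        }
        std::printf("\n");
    }
    return 0;
}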
CK_TILE_HOST_DEVICE static void print()
{
printf("LDS instruction{");
//
printf("OGradT_LDS_READ: %d, ", OGradT_LDS_READ);
printf("OGrad_LDS_READ: %d, ", OGrad_LDS_READ);
printf("QT_LDS_READ: %d, ", QT_LDS_READ);
printf("Q_LDS_READ: %d, ", Q_LDS_READ);
printf("SGradT_LDS_READ_P1: %d, ", SGradT_LDS_READ_P1);
printf("SGradT_LDS_READ_P2: %d, ", SGradT_LDS_READ_P2);
printf("LSE_LDS_READ: %d, ", LSE_LDS_READ);
printf("D_LDS_READ: %d, ", D_LDS_READ);
printf("}");
}
private:
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
@@ -1818,6 +1883,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t WarpGemmN =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;
static constexpr index_t Gemm0MWarp =
Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
static constexpr index_t Gemm2MWarp =
Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
static constexpr index_t Gemm4MWarp =
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
static constexpr index_t Gemm4NWarp =
@@ -1847,20 +1916,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t D_VMEM_READ = 1;
// LDS Read
// 16 * 128 / 64 / 4 = 8
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
// 16 * 128 / 64 / 4 = 8
static constexpr index_t QT_LDS_READ =
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
// 16 * 32 / 64 / 8 = 1
static constexpr index_t SGradT_LDS_READ_P1 =
// kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2;
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
// 16 * 128 / 64 / 8 = 4
static constexpr index_t Q_LDS_READ =
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetAlignmentQ<Problem>();
// 1
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// 16 * 96 / 64 / 8 = 3
static constexpr index_t SGradT_LDS_READ_P2 =
// kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2;
// 16 * 128 / 64 / 8 = 4
static constexpr index_t OGrad_LDS_READ =
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetAlignmentOGrad<Problem>();
// 1
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write
......