Commit ccd2fb13 authored by aska-0096's avatar aska-0096
Browse files

temp save

parent 385ac815
......@@ -117,6 +117,7 @@ target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OP
set(STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS)
list(APPEND STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
list(APPEND STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
list(APPEND STANDALONE_EXAMPLE_FA_BWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
set(STANDALONE_EXAMPLE_FA_BWD "standalone_example_fa_bwd")
......
......@@ -19,7 +19,7 @@
// Convert DQ
using fmha_dtype_0 = FmhaBwdBf16;
using fmha_dtype_0 = FmhaBwdFp16;
using fmha_bwd_convert_dq_trait_0 =
ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
......@@ -43,7 +43,7 @@ using fmha_bwd_convert_dq_kernel_0 =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
FmhaBwdBf16,
FmhaBwdFp16,
false,
false,
false,
......@@ -132,14 +132,14 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdBf16>::AccDataType,
typename FmhaBwdTypeConfig<FmhaBwdBf16>::KGradDataType,
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
typename FmhaBwdTypeConfig<FmhaBwdFp16>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdBf16>::AccDataType,
typename FmhaBwdTypeConfig<FmhaBwdBf16>::VGradDataType,
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
typename FmhaBwdTypeConfig<FmhaBwdFp16>::VGradDataType,
false,
false>>;
......@@ -149,7 +149,7 @@ using fmha_bwd_dq_dk_dv_kernel_0 =
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
FmhaBwdBf16,
FmhaBwdFp16,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
......@@ -201,7 +201,7 @@ using fmha_bwd_dot_do_o_kernel_0 =
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
using dot_do_o_trait_0 =
fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
......@@ -254,11 +254,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){
float r = -1;
if(t.data_type.compare("bf16") == 0 && (t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) {
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdBf16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}
......@@ -345,7 +345,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
}
template <>
auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
auto get_elimit<FmhaBwdFp16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
double rtol = 1e-2;
double atol = 1e-2;
......@@ -806,9 +806,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
// using instance:
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdBf16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
// using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
// using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
// return r;
if(ave_time < 0)
......@@ -1231,7 +1231,7 @@ int main(int argc, char* argv[])
}
else if(data_type == "bf16")
{
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
}
return -3;
......
......@@ -532,21 +532,22 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// Hot loop
while(i_total_loops < (num_total_loop - 1))
{
// STAGE 1, Q@K Gemm0
// STAGE 1, Q@K Gemm0
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
__builtin_amdgcn_sched_barrier(0);
auto s_acc = SPBlockTileType{};
q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
do_block_tile = load_tile(do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
......@@ -658,12 +659,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}();
// STAGE 3, P^T@OGrad^T Gemm1
Policy::template PTFromGemm0CToGemm1A<Problem,
decltype(pt_reg_tensor),
decltype(p_gemm)>(pt_reg_tensor, p_gemm);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
// Policy::template PTFromGemm0CToGemm1A<Problem,
// decltype(pt_reg_tensor),
// decltype(p_gemm)>(pt_reg_tensor, p_gemm);
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
auto qt_reg_tensor = load_tile(qt_lds_read_window);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
......
......@@ -204,7 +204,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
: kMinVecLoad;
return kVecLoad;
}
......@@ -262,7 +262,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
: kMinVecLoad;
return kVecLoad;
}
......@@ -1292,7 +1292,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
// constexpr index_t kNPerBlock = 32;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
......@@ -1673,7 +1672,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
constexpr index_t VMEM_READ_INST =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
// Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
Q_VMEM_READ + OGrad_VMEM_READ;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
constexpr index_t MFMA_INST = Gemm0MFMA;
......@@ -1681,17 +1681,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
ignore = j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * MFMA_PER_VMEM_READ + j<LDS_READ_INST){
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
}
});
});
static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
......@@ -1708,12 +1709,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm1MFMA;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i <LDS_READ_INST){
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
}
});
}
......@@ -1727,12 +1729,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm2MFMA;
// To hide instruction issue latency
constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;
constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i < LDS_WRITE_INST){
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
}
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment