"examples/community/imagic_stable_diffusion.py" did not exist on "b2e2d1411ce394ef15c41aafb34b3c08beedff0f"
Commit 3d5b0755 authored by danyao12
Browse files

non-iglp pipeline for headdim padding cases

parent f8b14618
...@@ -14,11 +14,13 @@ from codegen.cpp_symbol_map import * ...@@ -14,11 +14,13 @@ from codegen.cpp_symbol_map import *
BWD_DQDKDV_PIPELINE_MAP = { BWD_DQDKDV_PIPELINE_MAP = {
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR", "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP",
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
} }
BWD_DQDKDV_PIPELINE_ENUM_MAP = { BWD_DQDKDV_PIPELINE_ENUM_MAP = {
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR", "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP",
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
} }
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
...@@ -408,7 +410,7 @@ class FmhaBwdDQDKDVKernel: ...@@ -408,7 +410,7 @@ class FmhaBwdDQDKDVKernel:
if n != '' : n = 'p' + n if n != '' : n = 'p' + n
return n return n
pn = pad_name() pn = pad_name()
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
if pn != '' : n += f'_{pn}' if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}' if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_dbias == 't' : n += '_dbias' if self.F_dbias == 't' : n += '_dbias'
...@@ -450,13 +452,13 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict ...@@ -450,13 +452,13 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr"], "kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr"], "kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr"], "kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr"] "kr_ktr_vr_iglp", "kr_ktr_vr"]
} }
else: else:
return None return None
...@@ -481,6 +483,8 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -481,6 +483,8 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue continue
if ("wg32" in dropout): if ("wg32" in dropout):
continue continue
if (dpad == "t" or dvpad == "t"):
ppl = d[hdim_str][2]
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
...@@ -497,8 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -497,8 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if receipt == 3: if receipt == 3:
cond = dtype in ['fp16', 'bf16'] cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi'] cond &= bias in ['no', 'alibi']
cond &= dpad == "f" cond &= dpad == dvpad
cond &= dvpad == "f"
cond &= deterministic == "f" cond &= deterministic == "f"
if not cond: if not cond:
continue continue
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
......
...@@ -72,9 +72,12 @@ struct FmhaBwdDQDKDVKernel ...@@ -72,9 +72,12 @@ struct FmhaBwdDQDKDVKernel
{ {
// sync with generate.py // sync with generate.py
// clang-format off // clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape; using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps; using gbr0 = typename bfs::Gemm0BlockWarps;
using gwt = typename bfs::Gemm0WarpTile; using gbr1 = typename bfs::Gemm1BlockWarps;
using gbr4 = typename bfs::Gemm4BlockWarps;
using gwt0 = typename bfs::Gemm0WarpTile;
using gwt1 = typename bfs::Gemm1WarpTile;
#define _SS_ std::string #define _SS_ std::string
#define _TS_ std::to_string #define _TS_ std::to_string
auto pn = [&] () { auto pn = [&] () {
...@@ -87,10 +90,13 @@ struct FmhaBwdDQDKDVKernel ...@@ -87,10 +90,13 @@ struct FmhaBwdDQDKDVKernel
return return
_SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) + _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" + "_" + (kIsGroupMode ? "group" : "batch") + "_" +
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
_TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) +
......
...@@ -488,73 +488,37 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -488,73 +488,37 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static_assert(kM0 == kK3, "kM0 should equal to kK3"); static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4; constexpr index_t k4_loops = kN0 / kK4;
/*
* Prefetch Q, LSE, dO, D
*/
auto q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
auto lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
auto do_block_tile = load_tile(do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
auto d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
/*
* Store prefetched data into LDS
*/
store_tile(q_lds_window, q_block_tile);
shuffle_tile(qt_block_tile, q_block_tile);
store_tile(qt_lds_write_window, qt_block_tile);
store_tile(lse_lds_write_window, lse_block_tile);
store_tile(do_lds_window, do_block_tile);
shuffle_tile(dot_block_tile, do_block_tile);
store_tile(dot_lds_write_window, dot_block_tile);
store_tile(d_lds_write_window, d_block_tile);
block_sync_lds();
/*
* Prefetch LDS data into Reg to Asynchronous Data Movement and MFMA pipeline
*/
auto q_reg_tensor = load_tile(q_lds_read_window);
auto lse = load_tile(lse_lds_read_window);
auto do_reg_tensor = load_tile(do_lds_read_window);
auto d = load_tile(d_lds_read_window);
clear_tile(dv_acc); clear_tile(dv_acc);
clear_tile(dk_acc); clear_tile(dk_acc);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// Hot loop // Hot loop
while(i_total_loops < (num_total_loop - 1)) while(i_total_loops < num_total_loop)
{ {
// STAGE 1, Q@K Gemm0 auto q_block_tile = load_tile(q_dram_window);
auto st_acc = SPTBlockTileType{};
q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0}); move_tile_window(q_dram_window, {kM0, 0});
lse_block_tile = load_tile(lse_dram_window); auto lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0}); move_tile_window(lse_dram_window, {kM0});
do_block_tile = load_tile(do_dram_window); store_tile(q_lds_window, q_block_tile);
move_tile_window(do_dram_window, {kM0, 0}); shuffle_tile(qt_block_tile, q_block_tile);
store_tile(qt_lds_write_window, qt_block_tile);
d_block_tile = load_tile(d_dram_window); store_tile(lse_lds_write_window, lse_block_tile);
move_tile_window(d_dram_window, {kM0});
st_acc = gemm_0(q_reg_tensor, k_reg_tensor); block_sync_lds();
auto dot_reg_tensor = load_tile(dot_lds_read_window); auto q_reg_tensor = load_tile(q_lds_read_window);
auto lse = load_tile(lse_lds_read_window);
block_sync_lds();
// STAGE 1, Q@K Gemm0
auto st_acc = SPTBlockTileType{};
st_acc = gemm_0(q_reg_tensor, k_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<0>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -660,36 +624,38 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -660,36 +624,38 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}(); }();
// STAGE 3, P^T@OGrad^T Gemm1 // STAGE 3, P^T@OGrad^T Gemm1
Policy::template PTFromGemm0CToGemm1A<Problem, auto do_block_tile = load_tile(do_dram_window);
decltype(pt_reg_tensor), move_tile_window(do_dram_window, {kM0, 0});
decltype(pt_gemm)>(pt_reg_tensor, pt_gemm);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
auto qt_reg_tensor = load_tile(qt_lds_read_window); auto d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
HotLoopScheduler::template GemmStagedScheduler<1>(); store_tile(do_lds_window, do_block_tile);
__builtin_amdgcn_sched_barrier(0); shuffle_tile(dot_block_tile, do_block_tile);
// STAGE 4, OGrad@V Gemm2 store_tile(dot_lds_write_window, dot_block_tile);
auto dpt_acc = SPGradTBlockTileType{};
dpt_acc = gemm_2(do_reg_tensor, v_reg_tensor); store_tile(d_lds_write_window, d_block_tile);
block_sync_lds(); block_sync_lds();
store_tile(q_lds_window, q_block_tile); auto dot_reg_tensor = load_tile(dot_lds_read_window);
shuffle_tile(qt_block_tile, q_block_tile);
store_tile(qt_lds_write_window, qt_block_tile);
store_tile(lse_lds_write_window, lse_block_tile); block_sync_lds();
store_tile(do_lds_window, do_block_tile); Policy::template PTFromGemm0CToGemm1A<Problem,
shuffle_tile(dot_block_tile, do_block_tile); decltype(pt_reg_tensor),
store_tile(dot_lds_write_window, dot_block_tile); decltype(pt_gemm)>(pt_reg_tensor, pt_gemm);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
store_tile(d_lds_write_window, d_block_tile); // STAGE 4, OGrad@V Gemm2
auto do_reg_tensor = load_tile(do_lds_read_window);
auto d = load_tile(d_lds_read_window);
block_sync_lds();
auto dpt_acc = SPGradTBlockTileType{};
dpt_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D) // STAGE 5, P^T(PGrad^T - D)
auto dst = SPGradTBlockTileType{}; auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
...@@ -732,6 +698,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -732,6 +698,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
auto qt_reg_tensor = load_tile(qt_lds_read_window);
block_sync_lds();
const auto dst_gemm = cast_tile<GemmDataType>(dst); const auto dst_gemm = cast_tile<GemmDataType>(dst);
Policy::template SGradTFromGemm2CToGemm3A<Problem, Policy::template SGradTFromGemm2CToGemm3A<Problem,
...@@ -747,11 +716,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -747,11 +716,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
auto ds_reg_tensor = load_tile(ds_lds_read_window); auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){}; auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4}); move_tile_window(ds_lds_read_window, {0, kK4});
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE7 SGrad@K^T Gemm4 // STAGE7 SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{}; auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); clear_tile(dq_acc);
...@@ -773,12 +738,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -773,12 +738,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
} }
}); });
move_tile_window(ds_lds_read_window, {0, -kN0}); move_tile_window(ds_lds_read_window, {0, -kN0});
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<4>();
// QGrad Scale // QGrad Scale
if constexpr(FmhaDropout::IsDropout) if constexpr(FmhaDropout::IsDropout)
{ {
...@@ -802,234 +761,19 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -802,234 +761,19 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
i_total_loops += 1; i_total_loops += 1;
seqlen_q_step += kM0; seqlen_q_step += kM0;
} }
__builtin_amdgcn_sched_barrier(0);
// Tail
auto st_acc = SPTBlockTileType{};
// STAGE 1, Q@K Gemm0
st_acc = gemm_0(q_reg_tensor, k_reg_tensor);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const auto bias_tile = load_tile(bias_dram_window);
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
},
st_acc,
biast_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
st_acc(i_j_idx) *= scale;
position_encoding.update(st_acc(i_j_idx), row, col);
});
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
});
});
if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
const auto pt_gemm = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(pt_gemm)>(
pt_reg_tensor, pt_gemm);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
// STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{};
auto qt_reg_tensor = load_tile(qt_lds_read_window);
dpt_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
// STAGE 5, P^T(PGrad^T - D)
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) = pt[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
? (dpt_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
const auto dbiast = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_window, dbiast_shuffle_tmp);
}
// STAGE 6, SGrad^T@Q^T Gemm3
const auto dst_gemm = cast_tile<GemmDataType>(dst);
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(dst_gemm)>(dst_reg_tensor, dst_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, dst_gemm);
block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>();
// STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(
kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
}
});
HotLoopScheduler::template GemmStagedScheduler<4>();
// Results Scale // Results Scale
if constexpr(FmhaDropout::IsDropout) if constexpr(FmhaDropout::IsDropout)
{ {
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc); dk_acc);
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
} }
else else
{ {
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
} }
if constexpr(kIsDeterministic)
{
store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
}
return make_tuple(dk_acc, dv_acc); return make_tuple(dk_acc, dv_acc);
} }
}; };
......
...@@ -8,7 +8,8 @@ namespace ck_tile { ...@@ -8,7 +8,8 @@ namespace ck_tile {
// This class is used for codegen pattern matching // This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum enum class BlockFmhaBwdPipelineEnum
{ {
KRKTRVR = 0, KRKTRVR_IGLP = 0,
KRKTRVR,
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment