Commit 99436cd4 authored by danyao12's avatar danyao12
Browse files

save clear_tile

parent b3100b6f
......@@ -537,7 +537,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{
// STAGE 1, Q@K Gemm0
auto st_acc = SPTBlockTileType{};
clear_tile(st_acc);
q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
......@@ -551,7 +550,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
gemm_0(st_acc, q_reg_tensor, k_reg_tensor);
st_acc = gemm_0(q_reg_tensor, k_reg_tensor);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<0>();
......@@ -670,9 +670,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc);
gemm_2(dpt_acc, do_reg_tensor, v_reg_tensor);
dpt_acc = gemm_2(do_reg_tensor, v_reg_tensor);
block_sync_lds();
......@@ -804,10 +803,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// Tail
auto st_acc = SPTBlockTileType{};
clear_tile(st_acc);
// STAGE 1, Q@K Gemm0
gemm_0(st_acc, q_reg_tensor, k_reg_tensor);
st_acc = gemm_0(q_reg_tensor, k_reg_tensor);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
......@@ -919,10 +917,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc);
auto qt_reg_tensor = load_tile(qt_lds_read_window);
gemm_2(dpt_acc, do_reg_tensor, v_reg_tensor);
dpt_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment