Commit 99436cd4 authored by danyao12
Browse files

save clear_tile

parent b3100b6f
...@@ -537,7 +537,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -537,7 +537,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{ {
// STAGE 1, Q@K Gemm0 // STAGE 1, Q@K Gemm0
auto st_acc = SPTBlockTileType{}; auto st_acc = SPTBlockTileType{};
clear_tile(st_acc);
q_block_tile = load_tile(q_dram_window); q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0}); move_tile_window(q_dram_window, {kM0, 0});
...@@ -551,7 +550,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -551,7 +550,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
d_block_tile = load_tile(d_dram_window); d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0}); move_tile_window(d_dram_window, {kM0});
gemm_0(st_acc, q_reg_tensor, k_reg_tensor); st_acc = gemm_0(q_reg_tensor, k_reg_tensor);
auto dot_reg_tensor = load_tile(dot_lds_read_window); auto dot_reg_tensor = load_tile(dot_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<0>(); HotLoopScheduler::template GemmStagedScheduler<0>();
...@@ -670,9 +670,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -670,9 +670,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2 // STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{}; auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc);
gemm_2(dpt_acc, do_reg_tensor, v_reg_tensor); dpt_acc = gemm_2(do_reg_tensor, v_reg_tensor);
block_sync_lds(); block_sync_lds();
...@@ -804,10 +803,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -804,10 +803,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// Tail // Tail
auto st_acc = SPTBlockTileType{}; auto st_acc = SPTBlockTileType{};
clear_tile(st_acc);
// STAGE 1, Q@K Gemm0 // STAGE 1, Q@K Gemm0
gemm_0(st_acc, q_reg_tensor, k_reg_tensor); st_acc = gemm_0(q_reg_tensor, k_reg_tensor);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
...@@ -919,10 +917,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -919,10 +917,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 4, OGrad@V Gemm2 // STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{}; auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc);
auto qt_reg_tensor = load_tile(qt_lds_read_window); auto qt_reg_tensor = load_tile(qt_lds_read_window);
gemm_2(dpt_acc, do_reg_tensor, v_reg_tensor);
dpt_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>(); HotLoopScheduler::template GemmStagedScheduler<2>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment