Commit 382ac606 authored by Anthony Chang
Browse files

avoid LDS data hazard in gemm_softmax_gemm pipeline

parent 5ee30459
...@@ -717,7 +717,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -717,7 +717,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum; mathext::exp(max - running_max_new) * sum;
block_sync_lds();
// gemm1 // gemm1
{ {
// TODO: explore using dynamic buffer for a1 thread buffer // TODO: explore using dynamic buffer for a1 thread buffer
...@@ -736,12 +735,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -736,12 +735,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step); b1_block_slice_copy_step);
block_sync_lds(); // wait for gemm0 LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
// main body // main body
if constexpr(num_gemm1_k_block_inner_loop > 1) if constexpr(num_gemm1_k_block_inner_loop > 1)
{ {
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0), make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
...@@ -749,6 +749,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -749,6 +749,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
a1_thread_desc_k0_m_k1, a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
a1_thread_buf); a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds(); block_sync_lds();
...@@ -817,6 +818,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -817,6 +818,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max = running_max_new; running_max = running_max_new;
running_sum = running_sum_new; running_sum = running_sum_new;
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// shuffle C and write out // shuffle C and write out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment