Commit 382ac606 authored by Anthony Chang's avatar Anthony Chang
Browse files

avoid LDS data hazard in gemm_softmax_gemm pipeline

parent 5ee30459
......@@ -717,7 +717,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
block_sync_lds();
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
......@@ -736,12 +735,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for gemm0 LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
......@@ -749,6 +749,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds();
......@@ -817,6 +818,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max = running_max_new;
running_sum = running_sum_new;
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// shuffle C and write out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment