Commit 4d18cd84 authored by danyao12
Browse files

adjust block_sync_lds to solve read-write conflicts

parent 6cc7d0de
...@@ -2114,6 +2114,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2114,6 +2114,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
sgrad_slice_idx[I3], sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3)); sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run( qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
...@@ -2125,8 +2126,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2125,8 +2126,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
gemm2_a_block_buf); gemm2_a_block_buf);
} }
// block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3, qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
k_block_buf, k_block_buf,
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3, Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
......
...@@ -2044,6 +2044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2044,6 +2044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
sgrad_slice_idx[I3], sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3)); sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run( qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
...@@ -2060,7 +2061,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2060,7 +2061,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1, qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
Gemm2::b_block_slice_copy_step); Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1, qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
gemm2_b_block_buf); gemm2_b_block_buf);
......
...@@ -2191,6 +2191,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2191,6 +2191,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
sgrad_slice_idx[I3], sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3)); sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run( qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
...@@ -2202,8 +2203,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2202,8 +2203,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
gemm2_a_block_buf); gemm2_a_block_buf);
} }
// block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3, qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
k_block_buf, k_block_buf,
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3, Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
......
...@@ -2142,6 +2142,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2142,6 +2142,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
sgrad_slice_idx[I3], sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3)); sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run( qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
...@@ -2158,7 +2159,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2158,7 +2159,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1, qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
Gemm2::b_block_slice_copy_step); Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1, qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
gemm2_b_block_buf); gemm2_b_block_buf);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment