Commit 48a9c0b6 authored by Qianfeng Zhang
Browse files

Revert the change with regard to defining dynamic buffer ygrad_grid_buf

parent 8c67fac1
...@@ -1629,11 +1629,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1629,11 +1629,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize()); p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize()); p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
// ygrad dynamic buffer used for calculating y_dot_dy const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto ygrad_grid_buf1 = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
// ygrad dynamic buffer used for calculating dV = Pdrop^T * dY or dPdrop = dY * V^T
const auto ygrad_grid_buf2 = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize()); p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize()); p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
...@@ -2221,7 +2217,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2221,7 +2217,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
y_thread_buf); y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock, yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf1, ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1, y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
ygrad_thread_buf); ygrad_thread_buf);
...@@ -2505,7 +2501,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2505,7 +2501,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// preload data into LDS // preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf2); ygrad_grid_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow( vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step); ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);
...@@ -2526,7 +2522,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2526,7 +2522,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
gemm1_a_thread_buf); gemm1_a_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf2); ygrad_grid_buf);
block_sync_lds(); block_sync_lds();
...@@ -2566,7 +2562,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2566,7 +2562,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
// preload data into LDS // preload data into LDS
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf2); ygrad_grid_buf);
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow( pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step); ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
...@@ -2583,7 +2579,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2583,7 +2579,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
static_for<0, num_ok_block_main_loop - 1, 1>{}([&](auto i) { static_for<0, num_ok_block_main_loop - 1, 1>{}([&](auto i) {
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf2); ygrad_grid_buf);
block_sync_lds(); block_sync_lds();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment