Commit 48a9c0b6 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Revert the change with regard to defining dynamic buffer ygrad_gid_buf

parent 8c67fac1
......@@ -1629,11 +1629,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
// ygrad dynamic buffer used for calculating y_dot_dy
const auto ygrad_grid_buf1 = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
// ygrad dynamic buffer used for calculating dV = Pdrop^T * dY or dPdrop = dY * V^T
const auto ygrad_grid_buf2 = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
......@@ -2221,7 +2217,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(I0, I0, I0, I0),
y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf1,
ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
ygrad_thread_buf);
......@@ -2505,7 +2501,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf2);
ygrad_grid_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);
......@@ -2526,7 +2522,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf2);
ygrad_grid_buf);
block_sync_lds();
......@@ -2566,7 +2562,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
// preload data into LDS
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf2);
ygrad_grid_buf);
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
......@@ -2583,7 +2579,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
static_for<0, num_ok_block_main_loop - 1, 1>{}([&](auto i) {
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf2);
ygrad_grid_buf);
block_sync_lds();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment