Commit 70fca3ed authored by ltqin's avatar ltqin
Browse files

one block dk pass

parent 7af1b43a
...@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1, 1,
false>{ false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx_m, // mblock make_multi_index(block_work_idx_m, // mblock
acc0_thread_origin[I0], // mrepeat acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl acc0_thread_origin[I4])}; // mperxdl
...@@ -1511,14 +1511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1511,14 +1511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId 0, // NBlockId
0, // mrepeat 0, // mrepeat
0, // nrepeat 0, // nrepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl wave_m_n_id[I1], // MPerXdl
0, // group 0, // group
wave_m_n_id[I0], // NInputIndex wave_m_n_id[I0], // NInputIndex
0), 0),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
......
...@@ -1579,6 +1579,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1579,6 +1579,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4, vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::Scale{rp_dropout});
// dK:
// dV: blockwise gemm
auto k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
auto k_grad_thread_buf = k_grad_blockwise_gemm.GetCThreadBuffer();
k_grad_thread_buf.Clear();
// dK: transform input and output tensor descriptors // dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 = const auto q_grid_desc_m0_k_m1 =
KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1); KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);
...@@ -1959,50 +1966,92 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1959,50 +1966,92 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}); // end gemm dV }); // end gemm dV
// // gemm dP // gemm dP
// block_sync_lds(); block_sync_lds();
// // dP = dY * V^T // dP = dY * V^T
// // assume size K == size O so HasMainKBlockLoop is the same // assume size K == size O so HasMainKBlockLoop is the same
// gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>( gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
// ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
// Gemm0::a_block_desc_ak0_m_ak1, // reuse Gemm0::a_block_desc_ak0_m_ak1, // reuse
// pgrad_gemm_tile_ygrad_blockwise_copy, pgrad_gemm_tile_ygrad_blockwise_copy,
// ygrad_grid_buf, ygrad_grid_buf,
// gemm0_a_block_buf, // reuse gemm0_a_block_buf, // reuse
// Gemm0::a_block_slice_copy_step, // reuse Gemm0::a_block_slice_copy_step, // reuse
// v_grid_desc_o0_n_o1, v_grid_desc_o0_n_o1,
// Gemm0::b_block_desc_bk0_n_bk1, // reuse Gemm0::b_block_desc_bk0_n_bk1, // reuse
// pgrad_gemm_tile_v_blockwise_copy, pgrad_gemm_tile_v_blockwise_copy,
// v_grid_buf, v_grid_buf,
// gemm0_b_block_buf, // reuse gemm0_b_block_buf, // reuse
// Gemm0::b_block_slice_copy_step, // reuse Gemm0::b_block_slice_copy_step, // reuse
// pgrad_blockwise_gemm, pgrad_blockwise_gemm,
// pgrad_thread_buf, pgrad_thread_buf,
// num_o_block_main_loop); num_o_block_main_loop);
// // dS = P * (dP - Y_dot_dY) // dS = P * (dP - Y_dot_dY)
// auto& sgrad_thread_buf = pgrad_thread_buf; auto& sgrad_thread_buf = pgrad_thread_buf;
// constexpr auto pgrad_thread_tile_iterator = constexpr auto pgrad_thread_tile_iterator =
// pgrad_blockwise_gemm.MakeCThreadTileIterator(); pgrad_blockwise_gemm.MakeCThreadTileIterator();
// constexpr auto pgrad_thread_idx_to_m_n_adaptor = constexpr auto pgrad_thread_idx_to_m_n_adaptor =
// pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D(); pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
// static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) { static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
// constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i); constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
// constexpr auto m = constexpr auto m =
// pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0]; pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// // dS and P has same thread buf layout // dS and P has same thread buf layout
// if(s_slash_p_thread_buf[i] >= 0) if(s_slash_p_thread_buf[i] >= 0)
// { {
// sgrad_thread_buf(i) = sgrad_thread_buf(i) =
// s_slash_p_thread_buf[i] * s_slash_p_thread_buf[i] *
// (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]); (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
// } }
// else else
// { {
// sgrad_thread_buf(i) = sgrad_thread_buf(i) =
// s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}]; s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
// } }
// }); });
// dK = scalar * dS^T * Q
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
// load KGrad Gemm B
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
// load KGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range =
make_tuple(sgrad_slice_idx[I2],
sgrad_slice_idx[I2] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
constexpr auto nwave_range =
make_tuple(sgrad_slice_idx[I3],
sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(
sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
sgrad_thread_buf,
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
gemm2_a_block_buf);
}
// kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
// sgrad slice window is moved by loop index
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf);
block_sync_lds(); // sync before read
k_grad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, k_grad_thread_buf);
}); // end gemm dK
// // gemm dQ // // gemm dQ
// // dQ = scalar * dS * K // // dQ = scalar * dS * K
...@@ -2069,57 +2118,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2069,57 +2118,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// } // }
// } // end gemm dQ // } // end gemm dQ
// // dK = scalar * dS^T * Q
// v_slash_k_grad_thread_buf.Clear();
// static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
// // load KGrad Gemm B
// kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
// // load KGrad Gemm A
// const auto sgrad_slice_idx =
// Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
// constexpr auto mwave_range =
// make_tuple(sgrad_slice_idx[I2],
// sgrad_slice_idx[I2] +
// Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
// constexpr auto nwave_range =
// make_tuple(sgrad_slice_idx[I3],
// sgrad_slice_idx[I3] +
// Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
// if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
// {
// kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
// Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// make_tuple(
// sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
// sgrad_thread_buf,
// Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// gemm2_a_block_buf);
// }
// // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic
// buffer
// // sgrad slice window is moved by loop index
// kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
// Gemm2::b_block_slice_copy_step);
// block_sync_lds(); // sync before write
// kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
// gemm2_b_block_buf);
// block_sync_lds(); // sync before read
// v_slash_k_grad_blockwise_gemm.Run(
// gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
// }); // end gemm dK
// // atomic_add dK
// kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
// v_slash_k_grad_thread_buf,
// kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
// kgrad_grid_buf);
// move slice window // move slice window
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow( s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1, q_grid_desc_k0_m_k1,
...@@ -2131,17 +2129,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2131,17 +2129,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// ygrad_grid_desc_m0_o_m1, // ygrad_grid_desc_m0_o_m1,
// Gemm2::b_block_reset_copy_step); // rewind M // Gemm2::b_block_reset_copy_step); // rewind M
// vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow( // vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N // vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step
// N
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow( pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow( pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_o0_n_o1, v_grid_desc_o0_n_o1,
pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow( // kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1, // q_grid_desc_m0_k_m1,
Gemm2::b_block_reset_copy_step); // rewind M // Gemm2::b_block_reset_copy_step); // rewind M
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow( // kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N // kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step
// N
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
...@@ -2155,6 +2155,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2155,6 +2155,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
v_slash_k_grad_thread_buf, v_slash_k_grad_thread_buf,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_grid_buf); vgrad_grid_buf);
// atomic_add dK
kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
k_grad_thread_buf,
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_grid_buf);
ignore = c_element_op; ignore = c_element_op;
ignore = qgrad_grid_buf; ignore = qgrad_grid_buf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment