Commit 70fca3ed authored by ltqin's avatar ltqin
Browse files

one block dk pass

parent 7af1b43a
......@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1,
false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx_m, // mblock
make_multi_index(block_work_idx_m, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
......@@ -1511,14 +1511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
......
......@@ -1579,6 +1579,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
tensor_operation::element_wise::Scale{rp_dropout});
// dK:
// dV: blockwise gemm
auto k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
auto k_grad_thread_buf = k_grad_blockwise_gemm.GetCThreadBuffer();
k_grad_thread_buf.Clear();
// dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 =
KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);
......@@ -1959,50 +1966,92 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}); // end gemm dV
// // gemm dP
// block_sync_lds();
// // dP = dY * V^T
// // assume size K == size O so HasMainKBlockLoop is the same
// gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
// ygrad_grid_desc_o0_m_o1,
// Gemm0::a_block_desc_ak0_m_ak1, // reuse
// pgrad_gemm_tile_ygrad_blockwise_copy,
// ygrad_grid_buf,
// gemm0_a_block_buf, // reuse
// Gemm0::a_block_slice_copy_step, // reuse
// v_grid_desc_o0_n_o1,
// Gemm0::b_block_desc_bk0_n_bk1, // reuse
// pgrad_gemm_tile_v_blockwise_copy,
// v_grid_buf,
// gemm0_b_block_buf, // reuse
// Gemm0::b_block_slice_copy_step, // reuse
// pgrad_blockwise_gemm,
// pgrad_thread_buf,
// num_o_block_main_loop);
// // dS = P * (dP - Y_dot_dY)
// auto& sgrad_thread_buf = pgrad_thread_buf;
// constexpr auto pgrad_thread_tile_iterator =
// pgrad_blockwise_gemm.MakeCThreadTileIterator();
// constexpr auto pgrad_thread_idx_to_m_n_adaptor =
// pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
// static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
// constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
// constexpr auto m =
// pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// // dS and P has same thread buf layout
// if(s_slash_p_thread_buf[i] >= 0)
// {
// sgrad_thread_buf(i) =
// s_slash_p_thread_buf[i] *
// (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
// }
// else
// {
// sgrad_thread_buf(i) =
// s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
// }
// });
// gemm dP
block_sync_lds();
// dP = dY * V^T
// assume size K == size O so HasMainKBlockLoop is the same
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
Gemm0::a_block_desc_ak0_m_ak1, // reuse
pgrad_gemm_tile_ygrad_blockwise_copy,
ygrad_grid_buf,
gemm0_a_block_buf, // reuse
Gemm0::a_block_slice_copy_step, // reuse
v_grid_desc_o0_n_o1,
Gemm0::b_block_desc_bk0_n_bk1, // reuse
pgrad_gemm_tile_v_blockwise_copy,
v_grid_buf,
gemm0_b_block_buf, // reuse
Gemm0::b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
// dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
constexpr auto pgrad_thread_idx_to_m_n_adaptor =
pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout
if(s_slash_p_thread_buf[i] >= 0)
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
}
else
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
}
});
// dK = scalar * dS^T * Q
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
// load KGrad Gemm B
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
// load KGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range =
make_tuple(sgrad_slice_idx[I2],
sgrad_slice_idx[I2] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
constexpr auto nwave_range =
make_tuple(sgrad_slice_idx[I3],
sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(
sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
sgrad_thread_buf,
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
gemm2_a_block_buf);
}
// kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
// sgrad slice window is moved by loop index
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf);
block_sync_lds(); // sync before read
k_grad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, k_grad_thread_buf);
}); // end gemm dK
// // gemm dQ
// // dQ = scalar * dS * K
......@@ -2069,57 +2118,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// }
// } // end gemm dQ
// // dK = scalar * dS^T * Q
// v_slash_k_grad_thread_buf.Clear();
// static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
// // load KGrad Gemm B
// kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
// // load KGrad Gemm A
// const auto sgrad_slice_idx =
// Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
// constexpr auto mwave_range =
// make_tuple(sgrad_slice_idx[I2],
// sgrad_slice_idx[I2] +
// Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
// constexpr auto nwave_range =
// make_tuple(sgrad_slice_idx[I3],
// sgrad_slice_idx[I3] +
// Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
// if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
// {
// kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
// Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// make_tuple(
// sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
// sgrad_thread_buf,
// Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// gemm2_a_block_buf);
// }
// // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic
// buffer
// // sgrad slice window is moved by loop index
// kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
// Gemm2::b_block_slice_copy_step);
// block_sync_lds(); // sync before write
// kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
// gemm2_b_block_buf);
// block_sync_lds(); // sync before read
// v_slash_k_grad_blockwise_gemm.Run(
// gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
// }); // end gemm dK
// // atomic_add dK
// kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
// v_slash_k_grad_thread_buf,
// kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
// kgrad_grid_buf);
// move slice window
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
......@@ -2131,17 +2129,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// ygrad_grid_desc_m0_o_m1,
// Gemm2::b_block_reset_copy_step); // rewind M
// vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
// vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step
// N
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_o0_n_o1,
pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1,
Gemm2::b_block_reset_copy_step); // rewind M
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
// kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
// q_grid_desc_m0_k_m1,
// Gemm2::b_block_reset_copy_step); // rewind M
// kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step
// N
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......@@ -2155,6 +2155,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
v_slash_k_grad_thread_buf,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_grid_buf);
// atomic_add dK
kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
k_grad_thread_buf,
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_grid_buf);
ignore = c_element_op;
ignore = qgrad_grid_buf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment